]> git.proxmox.com Git - ceph.git/blame - ceph/src/arrow/cpp/src/arrow/compute/function.cc
import quincy 17.2.0
[ceph.git] / ceph / src / arrow / cpp / src / arrow / compute / function.cc
CommitLineData
1d09f67e
TL
1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements. See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership. The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License. You may obtain a copy of the License at
8//
9// http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied. See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18#include "arrow/compute/function.h"
19
20#include <cstddef>
21#include <memory>
22#include <sstream>
23
24#include "arrow/compute/api_scalar.h"
25#include "arrow/compute/cast.h"
26#include "arrow/compute/exec.h"
27#include "arrow/compute/exec_internal.h"
28#include "arrow/compute/function_internal.h"
29#include "arrow/compute/kernels/common.h"
30#include "arrow/compute/registry.h"
31#include "arrow/datum.h"
32#include "arrow/util/cpu_info.h"
33#include "arrow/util/logging.h"
34
35namespace arrow {
36
37using internal::checked_cast;
38
39namespace compute {
40Result<std::shared_ptr<Buffer>> FunctionOptionsType::Serialize(
41 const FunctionOptions&) const {
42 return Status::NotImplemented("Serialize for ", type_name());
43}
44
45Result<std::unique_ptr<FunctionOptions>> FunctionOptionsType::Deserialize(
46 const Buffer& buffer) const {
47 return Status::NotImplemented("Deserialize for ", type_name());
48}
49
50std::string FunctionOptions::ToString() const { return options_type()->Stringify(*this); }
51
52bool FunctionOptions::Equals(const FunctionOptions& other) const {
53 if (this == &other) return true;
54 if (options_type() != other.options_type()) return false;
55 return options_type()->Compare(*this, other);
56}
57
58std::unique_ptr<FunctionOptions> FunctionOptions::Copy() const {
59 return options_type()->Copy(*this);
60}
61
62Result<std::shared_ptr<Buffer>> FunctionOptions::Serialize() const {
63 return options_type()->Serialize(*this);
64}
65
66Result<std::unique_ptr<FunctionOptions>> FunctionOptions::Deserialize(
67 const std::string& type_name, const Buffer& buffer) {
68 ARROW_ASSIGN_OR_RAISE(auto options,
69 GetFunctionRegistry()->GetFunctionOptionsType(type_name));
70 return options->Deserialize(buffer);
71}
72
73void PrintTo(const FunctionOptions& options, std::ostream* os) {
74 *os << options.ToString();
75}
76
77static const FunctionDoc kEmptyFunctionDoc{};
78
79const FunctionDoc& FunctionDoc::Empty() { return kEmptyFunctionDoc; }
80
81static Status CheckArityImpl(const Function* function, int passed_num_args,
82 const char* passed_num_args_label) {
83 if (function->arity().is_varargs && passed_num_args < function->arity().num_args) {
84 return Status::Invalid("VarArgs function ", function->name(), " needs at least ",
85 function->arity().num_args, " arguments but ",
86 passed_num_args_label, " only ", passed_num_args);
87 }
88
89 if (!function->arity().is_varargs && passed_num_args != function->arity().num_args) {
90 return Status::Invalid("Function ", function->name(), " accepts ",
91 function->arity().num_args, " arguments but ",
92 passed_num_args_label, " ", passed_num_args);
93 }
94
95 return Status::OK();
96}
97
98Status Function::CheckArity(const std::vector<InputType>& in_types) const {
99 return CheckArityImpl(this, static_cast<int>(in_types.size()), "kernel accepts");
100}
101
102Status Function::CheckArity(const std::vector<ValueDescr>& descrs) const {
103 return CheckArityImpl(this, static_cast<int>(descrs.size()),
104 "attempted to look up kernel(s) with");
105}
106
107namespace detail {
108
109Status NoMatchingKernel(const Function* func, const std::vector<ValueDescr>& descrs) {
110 return Status::NotImplemented("Function ", func->name(),
111 " has no kernel matching input types ",
112 ValueDescr::ToString(descrs));
113}
114
115template <typename KernelType>
116const KernelType* DispatchExactImpl(const std::vector<KernelType*>& kernels,
117 const std::vector<ValueDescr>& values) {
118 const KernelType* kernel_matches[SimdLevel::MAX] = {nullptr};
119
120 // Validate arity
121 for (const auto& kernel : kernels) {
122 if (kernel->signature->MatchesInputs(values)) {
123 kernel_matches[kernel->simd_level] = kernel;
124 }
125 }
126
127 // Dispatch as the CPU feature
128#if defined(ARROW_HAVE_RUNTIME_AVX512) || defined(ARROW_HAVE_RUNTIME_AVX2)
129 auto cpu_info = arrow::internal::CpuInfo::GetInstance();
130#endif
131#if defined(ARROW_HAVE_RUNTIME_AVX512)
132 if (cpu_info->IsSupported(arrow::internal::CpuInfo::AVX512)) {
133 if (kernel_matches[SimdLevel::AVX512]) {
134 return kernel_matches[SimdLevel::AVX512];
135 }
136 }
137#endif
138#if defined(ARROW_HAVE_RUNTIME_AVX2)
139 if (cpu_info->IsSupported(arrow::internal::CpuInfo::AVX2)) {
140 if (kernel_matches[SimdLevel::AVX2]) {
141 return kernel_matches[SimdLevel::AVX2];
142 }
143 }
144#endif
145 if (kernel_matches[SimdLevel::NONE]) {
146 return kernel_matches[SimdLevel::NONE];
147 }
148
149 return nullptr;
150}
151
152const Kernel* DispatchExactImpl(const Function* func,
153 const std::vector<ValueDescr>& values) {
154 if (func->kind() == Function::SCALAR) {
155 return DispatchExactImpl(checked_cast<const ScalarFunction*>(func)->kernels(),
156 values);
157 }
158
159 if (func->kind() == Function::VECTOR) {
160 return DispatchExactImpl(checked_cast<const VectorFunction*>(func)->kernels(),
161 values);
162 }
163
164 if (func->kind() == Function::SCALAR_AGGREGATE) {
165 return DispatchExactImpl(
166 checked_cast<const ScalarAggregateFunction*>(func)->kernels(), values);
167 }
168
169 if (func->kind() == Function::HASH_AGGREGATE) {
170 return DispatchExactImpl(checked_cast<const HashAggregateFunction*>(func)->kernels(),
171 values);
172 }
173
174 return nullptr;
175}
176
177} // namespace detail
178
179Result<const Kernel*> Function::DispatchExact(
180 const std::vector<ValueDescr>& values) const {
181 if (kind_ == Function::META) {
182 return Status::NotImplemented("Dispatch for a MetaFunction's Kernels");
183 }
184 RETURN_NOT_OK(CheckArity(values));
185
186 if (auto kernel = detail::DispatchExactImpl(this, values)) {
187 return kernel;
188 }
189 return detail::NoMatchingKernel(this, values);
190}
191
192Result<const Kernel*> Function::DispatchBest(std::vector<ValueDescr>* values) const {
193 // TODO(ARROW-11508) permit generic conversions here
194 return DispatchExact(*values);
195}
196
197Result<Datum> Function::Execute(const std::vector<Datum>& args,
198 const FunctionOptions* options, ExecContext* ctx) const {
199 if (options == nullptr) {
200 options = default_options();
201 }
202 if (ctx == nullptr) {
203 ExecContext default_ctx;
204 return Execute(args, options, &default_ctx);
205 }
206
207 // type-check Datum arguments here. Really we'd like to avoid this as much as
208 // possible
209 RETURN_NOT_OK(detail::CheckAllValues(args));
210 std::vector<ValueDescr> inputs(args.size());
211 for (size_t i = 0; i != args.size(); ++i) {
212 inputs[i] = args[i].descr();
213 }
214
215 ARROW_ASSIGN_OR_RAISE(auto kernel, DispatchBest(&inputs));
216 ARROW_ASSIGN_OR_RAISE(auto implicitly_cast_args, Cast(args, inputs, ctx));
217
218 std::unique_ptr<KernelState> state;
219
220 KernelContext kernel_ctx{ctx};
221 if (kernel->init) {
222 ARROW_ASSIGN_OR_RAISE(state, kernel->init(&kernel_ctx, {kernel, inputs, options}));
223 kernel_ctx.SetState(state.get());
224 }
225
226 std::unique_ptr<detail::KernelExecutor> executor;
227 if (kind() == Function::SCALAR) {
228 executor = detail::KernelExecutor::MakeScalar();
229 } else if (kind() == Function::VECTOR) {
230 executor = detail::KernelExecutor::MakeVector();
231 } else if (kind() == Function::SCALAR_AGGREGATE) {
232 executor = detail::KernelExecutor::MakeScalarAggregate();
233 } else {
234 return Status::NotImplemented("Direct execution of HASH_AGGREGATE functions");
235 }
236 RETURN_NOT_OK(executor->Init(&kernel_ctx, {kernel, inputs, options}));
237
238 detail::DatumAccumulator listener;
239 RETURN_NOT_OK(executor->Execute(implicitly_cast_args, &listener));
240 const auto out = executor->WrapResults(implicitly_cast_args, listener.values());
241#ifndef NDEBUG
242 DCHECK_OK(executor->CheckResultType(out, name_.c_str()));
243#endif
244 return out;
245}
246
247Status Function::Validate() const {
248 if (!doc_->summary.empty()) {
249 // Documentation given, check its contents
250 int arg_count = static_cast<int>(doc_->arg_names.size());
251 if (arg_count == arity_.num_args) {
252 return Status::OK();
253 }
254 if (arity_.is_varargs && arg_count == arity_.num_args + 1) {
255 return Status::OK();
256 }
257 return Status::Invalid(
258 "In function '", name_,
259 "': ", "number of argument names for function documentation != function arity");
260 }
261 return Status::OK();
262}
263
264Status ScalarFunction::AddKernel(std::vector<InputType> in_types, OutputType out_type,
265 ArrayKernelExec exec, KernelInit init) {
266 RETURN_NOT_OK(CheckArity(in_types));
267
268 if (arity_.is_varargs && in_types.size() != 1) {
269 return Status::Invalid("VarArgs signatures must have exactly one input type");
270 }
271 auto sig =
272 KernelSignature::Make(std::move(in_types), std::move(out_type), arity_.is_varargs);
273 kernels_.emplace_back(std::move(sig), exec, init);
274 return Status::OK();
275}
276
277Status ScalarFunction::AddKernel(ScalarKernel kernel) {
278 RETURN_NOT_OK(CheckArity(kernel.signature->in_types()));
279 if (arity_.is_varargs && !kernel.signature->is_varargs()) {
280 return Status::Invalid("Function accepts varargs but kernel signature does not");
281 }
282 kernels_.emplace_back(std::move(kernel));
283 return Status::OK();
284}
285
286Status VectorFunction::AddKernel(std::vector<InputType> in_types, OutputType out_type,
287 ArrayKernelExec exec, KernelInit init) {
288 RETURN_NOT_OK(CheckArity(in_types));
289
290 if (arity_.is_varargs && in_types.size() != 1) {
291 return Status::Invalid("VarArgs signatures must have exactly one input type");
292 }
293 auto sig =
294 KernelSignature::Make(std::move(in_types), std::move(out_type), arity_.is_varargs);
295 kernels_.emplace_back(std::move(sig), exec, init);
296 return Status::OK();
297}
298
299Status VectorFunction::AddKernel(VectorKernel kernel) {
300 RETURN_NOT_OK(CheckArity(kernel.signature->in_types()));
301 if (arity_.is_varargs && !kernel.signature->is_varargs()) {
302 return Status::Invalid("Function accepts varargs but kernel signature does not");
303 }
304 kernels_.emplace_back(std::move(kernel));
305 return Status::OK();
306}
307
308Status ScalarAggregateFunction::AddKernel(ScalarAggregateKernel kernel) {
309 RETURN_NOT_OK(CheckArity(kernel.signature->in_types()));
310 if (arity_.is_varargs && !kernel.signature->is_varargs()) {
311 return Status::Invalid("Function accepts varargs but kernel signature does not");
312 }
313 kernels_.emplace_back(std::move(kernel));
314 return Status::OK();
315}
316
317Status HashAggregateFunction::AddKernel(HashAggregateKernel kernel) {
318 RETURN_NOT_OK(CheckArity(kernel.signature->in_types()));
319 if (arity_.is_varargs && !kernel.signature->is_varargs()) {
320 return Status::Invalid("Function accepts varargs but kernel signature does not");
321 }
322 kernels_.emplace_back(std::move(kernel));
323 return Status::OK();
324}
325
326Result<Datum> MetaFunction::Execute(const std::vector<Datum>& args,
327 const FunctionOptions* options,
328 ExecContext* ctx) const {
329 RETURN_NOT_OK(
330 CheckArityImpl(this, static_cast<int>(args.size()), "attempted to Execute with"));
331
332 if (options == nullptr) {
333 options = default_options();
334 }
335 return ExecuteImpl(args, options, ctx);
336}
337
338} // namespace compute
339} // namespace arrow