]>
Commit | Line | Data |
---|---|---|
1d09f67e TL |
1 | // Licensed to the Apache Software Foundation (ASF) under one |
2 | // or more contributor license agreements. See the NOTICE file | |
3 | // distributed with this work for additional information | |
4 | // regarding copyright ownership. The ASF licenses this file | |
5 | // to you under the Apache License, Version 2.0 (the | |
6 | // "License"); you may not use this file except in compliance | |
7 | // with the License. You may obtain a copy of the License at | |
8 | // | |
9 | // http://www.apache.org/licenses/LICENSE-2.0 | |
10 | // | |
11 | // Unless required by applicable law or agreed to in writing, | |
12 | // software distributed under the License is distributed on an | |
13 | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | |
14 | // KIND, either express or implied. See the License for the | |
15 | // specific language governing permissions and limitations | |
16 | // under the License. | |
17 | ||
18 | #include "arrow/compute/function.h" | |
19 | ||
20 | #include <cstddef> | |
21 | #include <memory> | |
22 | #include <sstream> | |
23 | ||
24 | #include "arrow/compute/api_scalar.h" | |
25 | #include "arrow/compute/cast.h" | |
26 | #include "arrow/compute/exec.h" | |
27 | #include "arrow/compute/exec_internal.h" | |
28 | #include "arrow/compute/function_internal.h" | |
29 | #include "arrow/compute/kernels/common.h" | |
30 | #include "arrow/compute/registry.h" | |
31 | #include "arrow/datum.h" | |
32 | #include "arrow/util/cpu_info.h" | |
33 | #include "arrow/util/logging.h" | |
34 | ||
35 | namespace arrow { | |
36 | ||
37 | using internal::checked_cast; | |
38 | ||
39 | namespace compute { | |
40 | Result<std::shared_ptr<Buffer>> FunctionOptionsType::Serialize( | |
41 | const FunctionOptions&) const { | |
42 | return Status::NotImplemented("Serialize for ", type_name()); | |
43 | } | |
44 | ||
45 | Result<std::unique_ptr<FunctionOptions>> FunctionOptionsType::Deserialize( | |
46 | const Buffer& buffer) const { | |
47 | return Status::NotImplemented("Deserialize for ", type_name()); | |
48 | } | |
49 | ||
50 | std::string FunctionOptions::ToString() const { return options_type()->Stringify(*this); } | |
51 | ||
52 | bool FunctionOptions::Equals(const FunctionOptions& other) const { | |
53 | if (this == &other) return true; | |
54 | if (options_type() != other.options_type()) return false; | |
55 | return options_type()->Compare(*this, other); | |
56 | } | |
57 | ||
58 | std::unique_ptr<FunctionOptions> FunctionOptions::Copy() const { | |
59 | return options_type()->Copy(*this); | |
60 | } | |
61 | ||
62 | Result<std::shared_ptr<Buffer>> FunctionOptions::Serialize() const { | |
63 | return options_type()->Serialize(*this); | |
64 | } | |
65 | ||
66 | Result<std::unique_ptr<FunctionOptions>> FunctionOptions::Deserialize( | |
67 | const std::string& type_name, const Buffer& buffer) { | |
68 | ARROW_ASSIGN_OR_RAISE(auto options, | |
69 | GetFunctionRegistry()->GetFunctionOptionsType(type_name)); | |
70 | return options->Deserialize(buffer); | |
71 | } | |
72 | ||
73 | void PrintTo(const FunctionOptions& options, std::ostream* os) { | |
74 | *os << options.ToString(); | |
75 | } | |
76 | ||
77 | static const FunctionDoc kEmptyFunctionDoc{}; | |
78 | ||
79 | const FunctionDoc& FunctionDoc::Empty() { return kEmptyFunctionDoc; } | |
80 | ||
81 | static Status CheckArityImpl(const Function* function, int passed_num_args, | |
82 | const char* passed_num_args_label) { | |
83 | if (function->arity().is_varargs && passed_num_args < function->arity().num_args) { | |
84 | return Status::Invalid("VarArgs function ", function->name(), " needs at least ", | |
85 | function->arity().num_args, " arguments but ", | |
86 | passed_num_args_label, " only ", passed_num_args); | |
87 | } | |
88 | ||
89 | if (!function->arity().is_varargs && passed_num_args != function->arity().num_args) { | |
90 | return Status::Invalid("Function ", function->name(), " accepts ", | |
91 | function->arity().num_args, " arguments but ", | |
92 | passed_num_args_label, " ", passed_num_args); | |
93 | } | |
94 | ||
95 | return Status::OK(); | |
96 | } | |
97 | ||
98 | Status Function::CheckArity(const std::vector<InputType>& in_types) const { | |
99 | return CheckArityImpl(this, static_cast<int>(in_types.size()), "kernel accepts"); | |
100 | } | |
101 | ||
102 | Status Function::CheckArity(const std::vector<ValueDescr>& descrs) const { | |
103 | return CheckArityImpl(this, static_cast<int>(descrs.size()), | |
104 | "attempted to look up kernel(s) with"); | |
105 | } | |
106 | ||
107 | namespace detail { | |
108 | ||
109 | Status NoMatchingKernel(const Function* func, const std::vector<ValueDescr>& descrs) { | |
110 | return Status::NotImplemented("Function ", func->name(), | |
111 | " has no kernel matching input types ", | |
112 | ValueDescr::ToString(descrs)); | |
113 | } | |
114 | ||
115 | template <typename KernelType> | |
116 | const KernelType* DispatchExactImpl(const std::vector<KernelType*>& kernels, | |
117 | const std::vector<ValueDescr>& values) { | |
118 | const KernelType* kernel_matches[SimdLevel::MAX] = {nullptr}; | |
119 | ||
120 | // Validate arity | |
121 | for (const auto& kernel : kernels) { | |
122 | if (kernel->signature->MatchesInputs(values)) { | |
123 | kernel_matches[kernel->simd_level] = kernel; | |
124 | } | |
125 | } | |
126 | ||
127 | // Dispatch as the CPU feature | |
128 | #if defined(ARROW_HAVE_RUNTIME_AVX512) || defined(ARROW_HAVE_RUNTIME_AVX2) | |
129 | auto cpu_info = arrow::internal::CpuInfo::GetInstance(); | |
130 | #endif | |
131 | #if defined(ARROW_HAVE_RUNTIME_AVX512) | |
132 | if (cpu_info->IsSupported(arrow::internal::CpuInfo::AVX512)) { | |
133 | if (kernel_matches[SimdLevel::AVX512]) { | |
134 | return kernel_matches[SimdLevel::AVX512]; | |
135 | } | |
136 | } | |
137 | #endif | |
138 | #if defined(ARROW_HAVE_RUNTIME_AVX2) | |
139 | if (cpu_info->IsSupported(arrow::internal::CpuInfo::AVX2)) { | |
140 | if (kernel_matches[SimdLevel::AVX2]) { | |
141 | return kernel_matches[SimdLevel::AVX2]; | |
142 | } | |
143 | } | |
144 | #endif | |
145 | if (kernel_matches[SimdLevel::NONE]) { | |
146 | return kernel_matches[SimdLevel::NONE]; | |
147 | } | |
148 | ||
149 | return nullptr; | |
150 | } | |
151 | ||
152 | const Kernel* DispatchExactImpl(const Function* func, | |
153 | const std::vector<ValueDescr>& values) { | |
154 | if (func->kind() == Function::SCALAR) { | |
155 | return DispatchExactImpl(checked_cast<const ScalarFunction*>(func)->kernels(), | |
156 | values); | |
157 | } | |
158 | ||
159 | if (func->kind() == Function::VECTOR) { | |
160 | return DispatchExactImpl(checked_cast<const VectorFunction*>(func)->kernels(), | |
161 | values); | |
162 | } | |
163 | ||
164 | if (func->kind() == Function::SCALAR_AGGREGATE) { | |
165 | return DispatchExactImpl( | |
166 | checked_cast<const ScalarAggregateFunction*>(func)->kernels(), values); | |
167 | } | |
168 | ||
169 | if (func->kind() == Function::HASH_AGGREGATE) { | |
170 | return DispatchExactImpl(checked_cast<const HashAggregateFunction*>(func)->kernels(), | |
171 | values); | |
172 | } | |
173 | ||
174 | return nullptr; | |
175 | } | |
176 | ||
177 | } // namespace detail | |
178 | ||
179 | Result<const Kernel*> Function::DispatchExact( | |
180 | const std::vector<ValueDescr>& values) const { | |
181 | if (kind_ == Function::META) { | |
182 | return Status::NotImplemented("Dispatch for a MetaFunction's Kernels"); | |
183 | } | |
184 | RETURN_NOT_OK(CheckArity(values)); | |
185 | ||
186 | if (auto kernel = detail::DispatchExactImpl(this, values)) { | |
187 | return kernel; | |
188 | } | |
189 | return detail::NoMatchingKernel(this, values); | |
190 | } | |
191 | ||
192 | Result<const Kernel*> Function::DispatchBest(std::vector<ValueDescr>* values) const { | |
193 | // TODO(ARROW-11508) permit generic conversions here | |
194 | return DispatchExact(*values); | |
195 | } | |
196 | ||
197 | Result<Datum> Function::Execute(const std::vector<Datum>& args, | |
198 | const FunctionOptions* options, ExecContext* ctx) const { | |
199 | if (options == nullptr) { | |
200 | options = default_options(); | |
201 | } | |
202 | if (ctx == nullptr) { | |
203 | ExecContext default_ctx; | |
204 | return Execute(args, options, &default_ctx); | |
205 | } | |
206 | ||
207 | // type-check Datum arguments here. Really we'd like to avoid this as much as | |
208 | // possible | |
209 | RETURN_NOT_OK(detail::CheckAllValues(args)); | |
210 | std::vector<ValueDescr> inputs(args.size()); | |
211 | for (size_t i = 0; i != args.size(); ++i) { | |
212 | inputs[i] = args[i].descr(); | |
213 | } | |
214 | ||
215 | ARROW_ASSIGN_OR_RAISE(auto kernel, DispatchBest(&inputs)); | |
216 | ARROW_ASSIGN_OR_RAISE(auto implicitly_cast_args, Cast(args, inputs, ctx)); | |
217 | ||
218 | std::unique_ptr<KernelState> state; | |
219 | ||
220 | KernelContext kernel_ctx{ctx}; | |
221 | if (kernel->init) { | |
222 | ARROW_ASSIGN_OR_RAISE(state, kernel->init(&kernel_ctx, {kernel, inputs, options})); | |
223 | kernel_ctx.SetState(state.get()); | |
224 | } | |
225 | ||
226 | std::unique_ptr<detail::KernelExecutor> executor; | |
227 | if (kind() == Function::SCALAR) { | |
228 | executor = detail::KernelExecutor::MakeScalar(); | |
229 | } else if (kind() == Function::VECTOR) { | |
230 | executor = detail::KernelExecutor::MakeVector(); | |
231 | } else if (kind() == Function::SCALAR_AGGREGATE) { | |
232 | executor = detail::KernelExecutor::MakeScalarAggregate(); | |
233 | } else { | |
234 | return Status::NotImplemented("Direct execution of HASH_AGGREGATE functions"); | |
235 | } | |
236 | RETURN_NOT_OK(executor->Init(&kernel_ctx, {kernel, inputs, options})); | |
237 | ||
238 | detail::DatumAccumulator listener; | |
239 | RETURN_NOT_OK(executor->Execute(implicitly_cast_args, &listener)); | |
240 | const auto out = executor->WrapResults(implicitly_cast_args, listener.values()); | |
241 | #ifndef NDEBUG | |
242 | DCHECK_OK(executor->CheckResultType(out, name_.c_str())); | |
243 | #endif | |
244 | return out; | |
245 | } | |
246 | ||
247 | Status Function::Validate() const { | |
248 | if (!doc_->summary.empty()) { | |
249 | // Documentation given, check its contents | |
250 | int arg_count = static_cast<int>(doc_->arg_names.size()); | |
251 | if (arg_count == arity_.num_args) { | |
252 | return Status::OK(); | |
253 | } | |
254 | if (arity_.is_varargs && arg_count == arity_.num_args + 1) { | |
255 | return Status::OK(); | |
256 | } | |
257 | return Status::Invalid( | |
258 | "In function '", name_, | |
259 | "': ", "number of argument names for function documentation != function arity"); | |
260 | } | |
261 | return Status::OK(); | |
262 | } | |
263 | ||
264 | Status ScalarFunction::AddKernel(std::vector<InputType> in_types, OutputType out_type, | |
265 | ArrayKernelExec exec, KernelInit init) { | |
266 | RETURN_NOT_OK(CheckArity(in_types)); | |
267 | ||
268 | if (arity_.is_varargs && in_types.size() != 1) { | |
269 | return Status::Invalid("VarArgs signatures must have exactly one input type"); | |
270 | } | |
271 | auto sig = | |
272 | KernelSignature::Make(std::move(in_types), std::move(out_type), arity_.is_varargs); | |
273 | kernels_.emplace_back(std::move(sig), exec, init); | |
274 | return Status::OK(); | |
275 | } | |
276 | ||
277 | Status ScalarFunction::AddKernel(ScalarKernel kernel) { | |
278 | RETURN_NOT_OK(CheckArity(kernel.signature->in_types())); | |
279 | if (arity_.is_varargs && !kernel.signature->is_varargs()) { | |
280 | return Status::Invalid("Function accepts varargs but kernel signature does not"); | |
281 | } | |
282 | kernels_.emplace_back(std::move(kernel)); | |
283 | return Status::OK(); | |
284 | } | |
285 | ||
286 | Status VectorFunction::AddKernel(std::vector<InputType> in_types, OutputType out_type, | |
287 | ArrayKernelExec exec, KernelInit init) { | |
288 | RETURN_NOT_OK(CheckArity(in_types)); | |
289 | ||
290 | if (arity_.is_varargs && in_types.size() != 1) { | |
291 | return Status::Invalid("VarArgs signatures must have exactly one input type"); | |
292 | } | |
293 | auto sig = | |
294 | KernelSignature::Make(std::move(in_types), std::move(out_type), arity_.is_varargs); | |
295 | kernels_.emplace_back(std::move(sig), exec, init); | |
296 | return Status::OK(); | |
297 | } | |
298 | ||
299 | Status VectorFunction::AddKernel(VectorKernel kernel) { | |
300 | RETURN_NOT_OK(CheckArity(kernel.signature->in_types())); | |
301 | if (arity_.is_varargs && !kernel.signature->is_varargs()) { | |
302 | return Status::Invalid("Function accepts varargs but kernel signature does not"); | |
303 | } | |
304 | kernels_.emplace_back(std::move(kernel)); | |
305 | return Status::OK(); | |
306 | } | |
307 | ||
308 | Status ScalarAggregateFunction::AddKernel(ScalarAggregateKernel kernel) { | |
309 | RETURN_NOT_OK(CheckArity(kernel.signature->in_types())); | |
310 | if (arity_.is_varargs && !kernel.signature->is_varargs()) { | |
311 | return Status::Invalid("Function accepts varargs but kernel signature does not"); | |
312 | } | |
313 | kernels_.emplace_back(std::move(kernel)); | |
314 | return Status::OK(); | |
315 | } | |
316 | ||
317 | Status HashAggregateFunction::AddKernel(HashAggregateKernel kernel) { | |
318 | RETURN_NOT_OK(CheckArity(kernel.signature->in_types())); | |
319 | if (arity_.is_varargs && !kernel.signature->is_varargs()) { | |
320 | return Status::Invalid("Function accepts varargs but kernel signature does not"); | |
321 | } | |
322 | kernels_.emplace_back(std::move(kernel)); | |
323 | return Status::OK(); | |
324 | } | |
325 | ||
326 | Result<Datum> MetaFunction::Execute(const std::vector<Datum>& args, | |
327 | const FunctionOptions* options, | |
328 | ExecContext* ctx) const { | |
329 | RETURN_NOT_OK( | |
330 | CheckArityImpl(this, static_cast<int>(args.size()), "attempted to Execute with")); | |
331 | ||
332 | if (options == nullptr) { | |
333 | options = default_options(); | |
334 | } | |
335 | return ExecuteImpl(args, options, ctx); | |
336 | } | |
337 | ||
338 | } // namespace compute | |
339 | } // namespace arrow |