// Source: ceph.git (git.proxmox.com) - ceph/src/arrow/cpp/src/arrow/compute/registry.cc
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
#include "arrow/compute/registry.h"

#include <algorithm>
#include <memory>
#include <mutex>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include "arrow/compute/function.h"
#include "arrow/compute/function_internal.h"
#include "arrow/compute/registry_internal.h"
#include "arrow/status.h"
#include "arrow/util/logging.h"
35 class FunctionRegistry::FunctionRegistryImpl
{
37 Status
AddFunction(std::shared_ptr
<Function
> function
, bool allow_overwrite
) {
38 RETURN_NOT_OK(function
->Validate());
40 std::lock_guard
<std::mutex
> mutation_guard(lock_
);
42 const std::string
& name
= function
->name();
43 auto it
= name_to_function_
.find(name
);
44 if (it
!= name_to_function_
.end() && !allow_overwrite
) {
45 return Status::KeyError("Already have a function registered with name: ", name
);
47 name_to_function_
[name
] = std::move(function
);
51 Status
AddAlias(const std::string
& target_name
, const std::string
& source_name
) {
52 std::lock_guard
<std::mutex
> mutation_guard(lock_
);
54 auto it
= name_to_function_
.find(source_name
);
55 if (it
== name_to_function_
.end()) {
56 return Status::KeyError("No function registered with name: ", source_name
);
58 name_to_function_
[target_name
] = it
->second
;
62 Status
AddFunctionOptionsType(const FunctionOptionsType
* options_type
,
63 bool allow_overwrite
= false) {
64 std::lock_guard
<std::mutex
> mutation_guard(lock_
);
66 const std::string name
= options_type
->type_name();
67 auto it
= name_to_options_type_
.find(name
);
68 if (it
!= name_to_options_type_
.end() && !allow_overwrite
) {
69 return Status::KeyError(
70 "Already have a function options type registered with name: ", name
);
72 name_to_options_type_
[name
] = options_type
;
76 Result
<std::shared_ptr
<Function
>> GetFunction(const std::string
& name
) const {
77 auto it
= name_to_function_
.find(name
);
78 if (it
== name_to_function_
.end()) {
79 return Status::KeyError("No function registered with name: ", name
);
84 std::vector
<std::string
> GetFunctionNames() const {
85 std::vector
<std::string
> results
;
86 for (auto it
: name_to_function_
) {
87 results
.push_back(it
.first
);
89 std::sort(results
.begin(), results
.end());
93 Result
<const FunctionOptionsType
*> GetFunctionOptionsType(
94 const std::string
& name
) const {
95 auto it
= name_to_options_type_
.find(name
);
96 if (it
== name_to_options_type_
.end()) {
97 return Status::KeyError("No function options type registered with name: ", name
);
102 int num_functions() const { return static_cast<int>(name_to_function_
.size()); }
106 std::unordered_map
<std::string
, std::shared_ptr
<Function
>> name_to_function_
;
107 std::unordered_map
<std::string
, const FunctionOptionsType
*> name_to_options_type_
;
110 std::unique_ptr
<FunctionRegistry
> FunctionRegistry::Make() {
111 return std::unique_ptr
<FunctionRegistry
>(new FunctionRegistry());
114 FunctionRegistry::FunctionRegistry() { impl_
.reset(new FunctionRegistryImpl()); }
116 FunctionRegistry::~FunctionRegistry() {}
118 Status
FunctionRegistry::AddFunction(std::shared_ptr
<Function
> function
,
119 bool allow_overwrite
) {
120 return impl_
->AddFunction(std::move(function
), allow_overwrite
);
123 Status
FunctionRegistry::AddAlias(const std::string
& target_name
,
124 const std::string
& source_name
) {
125 return impl_
->AddAlias(target_name
, source_name
);
128 Status
FunctionRegistry::AddFunctionOptionsType(const FunctionOptionsType
* options_type
,
129 bool allow_overwrite
) {
130 return impl_
->AddFunctionOptionsType(options_type
, allow_overwrite
);
133 Result
<std::shared_ptr
<Function
>> FunctionRegistry::GetFunction(
134 const std::string
& name
) const {
135 return impl_
->GetFunction(name
);
138 std::vector
<std::string
> FunctionRegistry::GetFunctionNames() const {
139 return impl_
->GetFunctionNames();
142 Result
<const FunctionOptionsType
*> FunctionRegistry::GetFunctionOptionsType(
143 const std::string
& name
) const {
144 return impl_
->GetFunctionOptionsType(name
);
147 int FunctionRegistry::num_functions() const { return impl_
->num_functions(); }
151 static std::unique_ptr
<FunctionRegistry
> CreateBuiltInRegistry() {
152 auto registry
= FunctionRegistry::Make();
155 RegisterScalarArithmetic(registry
.get());
156 RegisterScalarBoolean(registry
.get());
157 RegisterScalarCast(registry
.get());
158 RegisterScalarComparison(registry
.get());
159 RegisterScalarIfElse(registry
.get());
160 RegisterScalarNested(registry
.get());
161 RegisterScalarSetLookup(registry
.get());
162 RegisterScalarStringAscii(registry
.get());
163 RegisterScalarTemporalBinary(registry
.get());
164 RegisterScalarTemporalUnary(registry
.get());
165 RegisterScalarValidity(registry
.get());
167 RegisterScalarOptions(registry
.get());
170 RegisterVectorArraySort(registry
.get());
171 RegisterVectorHash(registry
.get());
172 RegisterVectorNested(registry
.get());
173 RegisterVectorReplace(registry
.get());
174 RegisterVectorSelection(registry
.get());
175 RegisterVectorSort(registry
.get());
177 RegisterVectorOptions(registry
.get());
179 // Aggregate functions
180 RegisterHashAggregateBasic(registry
.get());
181 RegisterScalarAggregateBasic(registry
.get());
182 RegisterScalarAggregateMode(registry
.get());
183 RegisterScalarAggregateQuantile(registry
.get());
184 RegisterScalarAggregateTDigest(registry
.get());
185 RegisterScalarAggregateVariance(registry
.get());
187 RegisterAggregateOptions(registry
.get());
192 } // namespace internal
194 FunctionRegistry
* GetFunctionRegistry() {
195 static auto g_registry
= internal::CreateBuiltInRegistry();
196 return g_registry
.get();
199 } // namespace compute