]> git.proxmox.com Git - ceph.git/blob - ceph/src/arrow/cpp/src/arrow/compute/registry.cc
import quincy 17.2.0
[ceph.git] / ceph / src / arrow / cpp / src / arrow / compute / registry.cc
1 // Licensed to the Apache Software Foundation (ASF) under one
2 // or more contributor license agreements. See the NOTICE file
3 // distributed with this work for additional information
4 // regarding copyright ownership. The ASF licenses this file
5 // to you under the Apache License, Version 2.0 (the
6 // "License"); you may not use this file except in compliance
7 // with the License. You may obtain a copy of the License at
8 //
9 // http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing,
12 // software distributed under the License is distributed on an
13 // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14 // KIND, either express or implied. See the License for the
15 // specific language governing permissions and limitations
16 // under the License.
17
18 #include "arrow/compute/registry.h"
19
20 #include <algorithm>
21 #include <memory>
22 #include <mutex>
23 #include <unordered_map>
24 #include <utility>
25
26 #include "arrow/compute/function.h"
27 #include "arrow/compute/function_internal.h"
28 #include "arrow/compute/registry_internal.h"
29 #include "arrow/status.h"
30 #include "arrow/util/logging.h"
31
32 namespace arrow {
33 namespace compute {
34
35 class FunctionRegistry::FunctionRegistryImpl {
36 public:
37 Status AddFunction(std::shared_ptr<Function> function, bool allow_overwrite) {
38 RETURN_NOT_OK(function->Validate());
39
40 std::lock_guard<std::mutex> mutation_guard(lock_);
41
42 const std::string& name = function->name();
43 auto it = name_to_function_.find(name);
44 if (it != name_to_function_.end() && !allow_overwrite) {
45 return Status::KeyError("Already have a function registered with name: ", name);
46 }
47 name_to_function_[name] = std::move(function);
48 return Status::OK();
49 }
50
51 Status AddAlias(const std::string& target_name, const std::string& source_name) {
52 std::lock_guard<std::mutex> mutation_guard(lock_);
53
54 auto it = name_to_function_.find(source_name);
55 if (it == name_to_function_.end()) {
56 return Status::KeyError("No function registered with name: ", source_name);
57 }
58 name_to_function_[target_name] = it->second;
59 return Status::OK();
60 }
61
62 Status AddFunctionOptionsType(const FunctionOptionsType* options_type,
63 bool allow_overwrite = false) {
64 std::lock_guard<std::mutex> mutation_guard(lock_);
65
66 const std::string name = options_type->type_name();
67 auto it = name_to_options_type_.find(name);
68 if (it != name_to_options_type_.end() && !allow_overwrite) {
69 return Status::KeyError(
70 "Already have a function options type registered with name: ", name);
71 }
72 name_to_options_type_[name] = options_type;
73 return Status::OK();
74 }
75
76 Result<std::shared_ptr<Function>> GetFunction(const std::string& name) const {
77 auto it = name_to_function_.find(name);
78 if (it == name_to_function_.end()) {
79 return Status::KeyError("No function registered with name: ", name);
80 }
81 return it->second;
82 }
83
84 std::vector<std::string> GetFunctionNames() const {
85 std::vector<std::string> results;
86 for (auto it : name_to_function_) {
87 results.push_back(it.first);
88 }
89 std::sort(results.begin(), results.end());
90 return results;
91 }
92
93 Result<const FunctionOptionsType*> GetFunctionOptionsType(
94 const std::string& name) const {
95 auto it = name_to_options_type_.find(name);
96 if (it == name_to_options_type_.end()) {
97 return Status::KeyError("No function options type registered with name: ", name);
98 }
99 return it->second;
100 }
101
102 int num_functions() const { return static_cast<int>(name_to_function_.size()); }
103
104 private:
105 std::mutex lock_;
106 std::unordered_map<std::string, std::shared_ptr<Function>> name_to_function_;
107 std::unordered_map<std::string, const FunctionOptionsType*> name_to_options_type_;
108 };
109
110 std::unique_ptr<FunctionRegistry> FunctionRegistry::Make() {
111 return std::unique_ptr<FunctionRegistry>(new FunctionRegistry());
112 }
113
114 FunctionRegistry::FunctionRegistry() { impl_.reset(new FunctionRegistryImpl()); }
115
116 FunctionRegistry::~FunctionRegistry() {}
117
118 Status FunctionRegistry::AddFunction(std::shared_ptr<Function> function,
119 bool allow_overwrite) {
120 return impl_->AddFunction(std::move(function), allow_overwrite);
121 }
122
123 Status FunctionRegistry::AddAlias(const std::string& target_name,
124 const std::string& source_name) {
125 return impl_->AddAlias(target_name, source_name);
126 }
127
128 Status FunctionRegistry::AddFunctionOptionsType(const FunctionOptionsType* options_type,
129 bool allow_overwrite) {
130 return impl_->AddFunctionOptionsType(options_type, allow_overwrite);
131 }
132
133 Result<std::shared_ptr<Function>> FunctionRegistry::GetFunction(
134 const std::string& name) const {
135 return impl_->GetFunction(name);
136 }
137
138 std::vector<std::string> FunctionRegistry::GetFunctionNames() const {
139 return impl_->GetFunctionNames();
140 }
141
142 Result<const FunctionOptionsType*> FunctionRegistry::GetFunctionOptionsType(
143 const std::string& name) const {
144 return impl_->GetFunctionOptionsType(name);
145 }
146
147 int FunctionRegistry::num_functions() const { return impl_->num_functions(); }
148
149 namespace internal {
150
151 static std::unique_ptr<FunctionRegistry> CreateBuiltInRegistry() {
152 auto registry = FunctionRegistry::Make();
153
154 // Scalar functions
155 RegisterScalarArithmetic(registry.get());
156 RegisterScalarBoolean(registry.get());
157 RegisterScalarCast(registry.get());
158 RegisterScalarComparison(registry.get());
159 RegisterScalarIfElse(registry.get());
160 RegisterScalarNested(registry.get());
161 RegisterScalarSetLookup(registry.get());
162 RegisterScalarStringAscii(registry.get());
163 RegisterScalarTemporalBinary(registry.get());
164 RegisterScalarTemporalUnary(registry.get());
165 RegisterScalarValidity(registry.get());
166
167 RegisterScalarOptions(registry.get());
168
169 // Vector functions
170 RegisterVectorArraySort(registry.get());
171 RegisterVectorHash(registry.get());
172 RegisterVectorNested(registry.get());
173 RegisterVectorReplace(registry.get());
174 RegisterVectorSelection(registry.get());
175 RegisterVectorSort(registry.get());
176
177 RegisterVectorOptions(registry.get());
178
179 // Aggregate functions
180 RegisterHashAggregateBasic(registry.get());
181 RegisterScalarAggregateBasic(registry.get());
182 RegisterScalarAggregateMode(registry.get());
183 RegisterScalarAggregateQuantile(registry.get());
184 RegisterScalarAggregateTDigest(registry.get());
185 RegisterScalarAggregateVariance(registry.get());
186
187 RegisterAggregateOptions(registry.get());
188
189 return registry;
190 }
191
192 } // namespace internal
193
194 FunctionRegistry* GetFunctionRegistry() {
195 static auto g_registry = internal::CreateBuiltInRegistry();
196 return g_registry.get();
197 }
198
199 } // namespace compute
200 } // namespace arrow