]> git.proxmox.com Git - ceph.git/blame - ceph/src/arrow/cpp/src/arrow/compute/cast.cc
import quincy 17.2.0
[ceph.git] / ceph / src / arrow / cpp / src / arrow / compute / cast.cc
CommitLineData
1d09f67e
TL
1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements. See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership. The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License. You may obtain a copy of the License at
8//
9// http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied. See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18#include "arrow/compute/cast.h"
19
20#include <mutex>
21#include <sstream>
22#include <string>
23#include <unordered_map>
24#include <unordered_set>
25#include <utility>
26#include <vector>
27
28#include "arrow/compute/cast_internal.h"
29#include "arrow/compute/exec.h"
30#include "arrow/compute/function_internal.h"
31#include "arrow/compute/kernel.h"
32#include "arrow/compute/kernels/codegen_internal.h"
33#include "arrow/compute/registry.h"
34#include "arrow/util/logging.h"
35#include "arrow/util/reflection_internal.h"
36
37namespace arrow {
38
39using internal::ToTypeName;
40
41namespace compute {
42namespace internal {
43
44// ----------------------------------------------------------------------
45// Function options
46
47namespace {
48
49std::unordered_map<int, std::shared_ptr<CastFunction>> g_cast_table;
50std::once_flag cast_table_initialized;
51
52void AddCastFunctions(const std::vector<std::shared_ptr<CastFunction>>& funcs) {
53 for (const auto& func : funcs) {
54 g_cast_table[static_cast<int>(func->out_type_id())] = func;
55 }
56}
57
58void InitCastTable() {
59 AddCastFunctions(GetBooleanCasts());
60 AddCastFunctions(GetBinaryLikeCasts());
61 AddCastFunctions(GetNestedCasts());
62 AddCastFunctions(GetNumericCasts());
63 AddCastFunctions(GetTemporalCasts());
64 AddCastFunctions(GetDictionaryCasts());
65}
66
67void EnsureInitCastTable() { std::call_once(cast_table_initialized, InitCastTable); }
68
69// Private version of GetCastFunction with better error reporting
70// if the input type is known.
71Result<std::shared_ptr<CastFunction>> GetCastFunctionInternal(
72 const std::shared_ptr<DataType>& to_type, const DataType* from_type = nullptr) {
73 internal::EnsureInitCastTable();
74 auto it = internal::g_cast_table.find(static_cast<int>(to_type->id()));
75 if (it == internal::g_cast_table.end()) {
76 if (from_type != nullptr) {
77 return Status::NotImplemented("Unsupported cast from ", *from_type, " to ",
78 *to_type,
79 " (no available cast function for target type)");
80 } else {
81 return Status::NotImplemented("Unsupported cast to ", *to_type,
82 " (no available cast function for target type)");
83 }
84 }
85 return it->second;
86}
87
88const FunctionDoc cast_doc{"Cast values to another data type",
89 ("Behavior when values wouldn't fit in the target type\n"
90 "can be controlled through CastOptions."),
91 {"input"},
92 "CastOptions"};
93
94// Metafunction for dispatching to appropriate CastFunction. This corresponds
95// to the standard SQL CAST(expr AS target_type)
96class CastMetaFunction : public MetaFunction {
97 public:
98 CastMetaFunction() : MetaFunction("cast", Arity::Unary(), &cast_doc) {}
99
100 Result<const CastOptions*> ValidateOptions(const FunctionOptions* options) const {
101 auto cast_options = static_cast<const CastOptions*>(options);
102
103 if (cast_options == nullptr || cast_options->to_type == nullptr) {
104 return Status::Invalid(
105 "Cast requires that options be passed with "
106 "the to_type populated");
107 }
108
109 return cast_options;
110 }
111
112 Result<Datum> ExecuteImpl(const std::vector<Datum>& args,
113 const FunctionOptions* options,
114 ExecContext* ctx) const override {
115 ARROW_ASSIGN_OR_RAISE(auto cast_options, ValidateOptions(options));
116 if (args[0].type()->Equals(*cast_options->to_type)) {
117 return args[0];
118 }
119 ARROW_ASSIGN_OR_RAISE(
120 std::shared_ptr<CastFunction> cast_func,
121 GetCastFunctionInternal(cast_options->to_type, args[0].type().get()));
122 return cast_func->Execute(args, options, ctx);
123 }
124};
125
126static auto kCastOptionsType = GetFunctionOptionsType<CastOptions>(
127 arrow::internal::DataMember("to_type", &CastOptions::to_type),
128 arrow::internal::DataMember("allow_int_overflow", &CastOptions::allow_int_overflow),
129 arrow::internal::DataMember("allow_time_truncate", &CastOptions::allow_time_truncate),
130 arrow::internal::DataMember("allow_time_overflow", &CastOptions::allow_time_overflow),
131 arrow::internal::DataMember("allow_decimal_truncate",
132 &CastOptions::allow_decimal_truncate),
133 arrow::internal::DataMember("allow_float_truncate",
134 &CastOptions::allow_float_truncate),
135 arrow::internal::DataMember("allow_invalid_utf8", &CastOptions::allow_invalid_utf8));
136} // namespace
137
138void RegisterScalarCast(FunctionRegistry* registry) {
139 DCHECK_OK(registry->AddFunction(std::make_shared<CastMetaFunction>()));
140 DCHECK_OK(registry->AddFunctionOptionsType(kCastOptionsType));
141}
142} // namespace internal
143
144CastOptions::CastOptions(bool safe)
145 : FunctionOptions(internal::kCastOptionsType),
146 allow_int_overflow(!safe),
147 allow_time_truncate(!safe),
148 allow_time_overflow(!safe),
149 allow_decimal_truncate(!safe),
150 allow_float_truncate(!safe),
151 allow_invalid_utf8(!safe) {}
152
153constexpr char CastOptions::kTypeName[];
154
155CastFunction::CastFunction(std::string name, Type::type out_type_id)
156 : ScalarFunction(std::move(name), Arity::Unary(), /*doc=*/nullptr),
157 out_type_id_(out_type_id) {}
158
159Status CastFunction::AddKernel(Type::type in_type_id, ScalarKernel kernel) {
160 // We use the same KernelInit for every cast
161 kernel.init = internal::CastState::Init;
162 RETURN_NOT_OK(ScalarFunction::AddKernel(kernel));
163 in_type_ids_.push_back(in_type_id);
164 return Status::OK();
165}
166
167Status CastFunction::AddKernel(Type::type in_type_id, std::vector<InputType> in_types,
168 OutputType out_type, ArrayKernelExec exec,
169 NullHandling::type null_handling,
170 MemAllocation::type mem_allocation) {
171 ScalarKernel kernel;
172 kernel.signature = KernelSignature::Make(std::move(in_types), std::move(out_type));
173 kernel.exec = exec;
174 kernel.null_handling = null_handling;
175 kernel.mem_allocation = mem_allocation;
176 return AddKernel(in_type_id, std::move(kernel));
177}
178
179Result<const Kernel*> CastFunction::DispatchExact(
180 const std::vector<ValueDescr>& values) const {
181 RETURN_NOT_OK(CheckArity(values));
182
183 std::vector<const ScalarKernel*> candidate_kernels;
184 for (const auto& kernel : kernels_) {
185 if (kernel.signature->MatchesInputs(values)) {
186 candidate_kernels.push_back(&kernel);
187 }
188 }
189
190 if (candidate_kernels.size() == 0) {
191 return Status::NotImplemented("Unsupported cast from ", values[0].type->ToString(),
192 " to ", ToTypeName(out_type_id_), " using function ",
193 this->name());
194 }
195
196 if (candidate_kernels.size() == 1) {
197 // One match, return it
198 return candidate_kernels[0];
199 }
200
201 // Now we are in a casting scenario where we may have both a EXACT_TYPE and
202 // a SAME_TYPE_ID. So we will see if there is an exact match among the
203 // candidate kernels and if not we will just return the first one
204 for (auto kernel : candidate_kernels) {
205 const InputType& arg0 = kernel->signature->in_types()[0];
206 if (arg0.kind() == InputType::EXACT_TYPE) {
207 // Bingo. Return it
208 return kernel;
209 }
210 }
211
212 // We didn't find an exact match. So just return some kernel that matches
213 return candidate_kernels[0];
214}
215
216Result<Datum> Cast(const Datum& value, const CastOptions& options, ExecContext* ctx) {
217 return CallFunction("cast", {value}, &options, ctx);
218}
219
220Result<Datum> Cast(const Datum& value, std::shared_ptr<DataType> to_type,
221 const CastOptions& options, ExecContext* ctx) {
222 CastOptions options_with_to_type = options;
223 options_with_to_type.to_type = to_type;
224 return Cast(value, options_with_to_type, ctx);
225}
226
227Result<std::shared_ptr<Array>> Cast(const Array& value, std::shared_ptr<DataType> to_type,
228 const CastOptions& options, ExecContext* ctx) {
229 ARROW_ASSIGN_OR_RAISE(Datum result, Cast(Datum(value), to_type, options, ctx));
230 return result.make_array();
231}
232
233Result<std::shared_ptr<CastFunction>> GetCastFunction(
234 const std::shared_ptr<DataType>& to_type) {
235 return internal::GetCastFunctionInternal(to_type);
236}
237
238bool CanCast(const DataType& from_type, const DataType& to_type) {
239 internal::EnsureInitCastTable();
240 auto it = internal::g_cast_table.find(static_cast<int>(to_type.id()));
241 if (it == internal::g_cast_table.end()) {
242 return false;
243 }
244
245 const CastFunction* function = it->second.get();
246 DCHECK_EQ(function->out_type_id(), to_type.id());
247
248 for (auto from_id : function->in_type_ids()) {
249 // XXX should probably check the output type as well
250 if (from_type.id() == from_id) return true;
251 }
252
253 return false;
254}
255
256Result<std::vector<Datum>> Cast(std::vector<Datum> datums, std::vector<ValueDescr> descrs,
257 ExecContext* ctx) {
258 for (size_t i = 0; i != datums.size(); ++i) {
259 if (descrs[i] != datums[i].descr()) {
260 if (descrs[i].shape != datums[i].shape()) {
261 return Status::NotImplemented("casting between Datum shapes");
262 }
263
264 ARROW_ASSIGN_OR_RAISE(datums[i],
265 Cast(datums[i], CastOptions::Safe(descrs[i].type), ctx));
266 }
267 }
268
269 return datums;
270}
271
272} // namespace compute
273} // namespace arrow