]>
Commit | Line | Data |
---|---|---|
1d09f67e TL |
1 | // Licensed to the Apache Software Foundation (ASF) under one |
2 | // or more contributor license agreements. See the NOTICE file | |
3 | // distributed with this work for additional information | |
4 | // regarding copyright ownership. The ASF licenses this file | |
5 | // to you under the Apache License, Version 2.0 (the | |
6 | // "License"); you may not use this file except in compliance | |
7 | // with the License. You may obtain a copy of the License at | |
8 | // | |
9 | // http://www.apache.org/licenses/LICENSE-2.0 | |
10 | // | |
11 | // Unless required by applicable law or agreed to in writing, | |
12 | // software distributed under the License is distributed on an | |
13 | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | |
14 | // KIND, either express or implied. See the License for the | |
15 | // specific language governing permissions and limitations | |
16 | // under the License. | |
17 | ||
18 | #include "arrow/compute/cast.h" | |
19 | ||
20 | #include <mutex> | |
21 | #include <sstream> | |
22 | #include <string> | |
23 | #include <unordered_map> | |
24 | #include <unordered_set> | |
25 | #include <utility> | |
26 | #include <vector> | |
27 | ||
28 | #include "arrow/compute/cast_internal.h" | |
29 | #include "arrow/compute/exec.h" | |
30 | #include "arrow/compute/function_internal.h" | |
31 | #include "arrow/compute/kernel.h" | |
32 | #include "arrow/compute/kernels/codegen_internal.h" | |
33 | #include "arrow/compute/registry.h" | |
34 | #include "arrow/util/logging.h" | |
35 | #include "arrow/util/reflection_internal.h" | |
36 | ||
37 | namespace arrow { | |
38 | ||
39 | using internal::ToTypeName; | |
40 | ||
41 | namespace compute { | |
42 | namespace internal { | |
43 | ||
44 | // ---------------------------------------------------------------------- | |
45 | // Function options | |
46 | ||
47 | namespace { | |
48 | ||
// Cast functions indexed by the integer value of their output type id
// (entries are written by AddCastFunctions).
std::unordered_map<int, std::shared_ptr<CastFunction>> g_cast_table;
// Flag handed to std::call_once so that InitCastTable runs a single time.
std::once_flag cast_table_initialized;
51 | ||
52 | void AddCastFunctions(const std::vector<std::shared_ptr<CastFunction>>& funcs) { | |
53 | for (const auto& func : funcs) { | |
54 | g_cast_table[static_cast<int>(func->out_type_id())] = func; | |
55 | } | |
56 | } | |
57 | ||
// Fills g_cast_table with the cast functions of every type family.
//
// AddCastFunctions assigns by output type id, so if two groups ever provided a
// function for the same id the group added last would win; keep the call order
// in mind when adding a new group. Invoked through EnsureInitCastTable.
void InitCastTable() {
  AddCastFunctions(GetBooleanCasts());
  AddCastFunctions(GetBinaryLikeCasts());
  AddCastFunctions(GetNestedCasts());
  AddCastFunctions(GetNumericCasts());
  AddCastFunctions(GetTemporalCasts());
  AddCastFunctions(GetDictionaryCasts());
}
66 | ||
67 | void EnsureInitCastTable() { std::call_once(cast_table_initialized, InitCastTable); } | |
68 | ||
69 | // Private version of GetCastFunction with better error reporting | |
70 | // if the input type is known. | |
71 | Result<std::shared_ptr<CastFunction>> GetCastFunctionInternal( | |
72 | const std::shared_ptr<DataType>& to_type, const DataType* from_type = nullptr) { | |
73 | internal::EnsureInitCastTable(); | |
74 | auto it = internal::g_cast_table.find(static_cast<int>(to_type->id())); | |
75 | if (it == internal::g_cast_table.end()) { | |
76 | if (from_type != nullptr) { | |
77 | return Status::NotImplemented("Unsupported cast from ", *from_type, " to ", | |
78 | *to_type, | |
79 | " (no available cast function for target type)"); | |
80 | } else { | |
81 | return Status::NotImplemented("Unsupported cast to ", *to_type, | |
82 | " (no available cast function for target type)"); | |
83 | } | |
84 | } | |
85 | return it->second; | |
86 | } | |
87 | ||
// Documentation record for the "cast" meta-function; its address is passed to
// the MetaFunction constructor by CastMetaFunction.
// NOTE(review): the initializers look like summary, description, argument
// names and options class name, in that order -- confirm against FunctionDoc.
const FunctionDoc cast_doc{"Cast values to another data type",
                           ("Behavior when values wouldn't fit in the target type\n"
                            "can be controlled through CastOptions."),
                           {"input"},
                           "CastOptions"};
93 | ||
94 | // Metafunction for dispatching to appropriate CastFunction. This corresponds | |
95 | // to the standard SQL CAST(expr AS target_type) | |
96 | class CastMetaFunction : public MetaFunction { | |
97 | public: | |
98 | CastMetaFunction() : MetaFunction("cast", Arity::Unary(), &cast_doc) {} | |
99 | ||
100 | Result<const CastOptions*> ValidateOptions(const FunctionOptions* options) const { | |
101 | auto cast_options = static_cast<const CastOptions*>(options); | |
102 | ||
103 | if (cast_options == nullptr || cast_options->to_type == nullptr) { | |
104 | return Status::Invalid( | |
105 | "Cast requires that options be passed with " | |
106 | "the to_type populated"); | |
107 | } | |
108 | ||
109 | return cast_options; | |
110 | } | |
111 | ||
112 | Result<Datum> ExecuteImpl(const std::vector<Datum>& args, | |
113 | const FunctionOptions* options, | |
114 | ExecContext* ctx) const override { | |
115 | ARROW_ASSIGN_OR_RAISE(auto cast_options, ValidateOptions(options)); | |
116 | if (args[0].type()->Equals(*cast_options->to_type)) { | |
117 | return args[0]; | |
118 | } | |
119 | ARROW_ASSIGN_OR_RAISE( | |
120 | std::shared_ptr<CastFunction> cast_func, | |
121 | GetCastFunctionInternal(cast_options->to_type, args[0].type().get())); | |
122 | return cast_func->Execute(args, options, ctx); | |
123 | } | |
124 | }; | |
125 | ||
// Options-type object for CastOptions, built from the named data members listed
// below. It is passed to the FunctionOptions base class by the CastOptions
// constructor and added to the function registry by RegisterScalarCast.
static auto kCastOptionsType = GetFunctionOptionsType<CastOptions>(
    arrow::internal::DataMember("to_type", &CastOptions::to_type),
    arrow::internal::DataMember("allow_int_overflow", &CastOptions::allow_int_overflow),
    arrow::internal::DataMember("allow_time_truncate", &CastOptions::allow_time_truncate),
    arrow::internal::DataMember("allow_time_overflow", &CastOptions::allow_time_overflow),
    arrow::internal::DataMember("allow_decimal_truncate",
                                &CastOptions::allow_decimal_truncate),
    arrow::internal::DataMember("allow_float_truncate",
                                &CastOptions::allow_float_truncate),
    arrow::internal::DataMember("allow_invalid_utf8", &CastOptions::allow_invalid_utf8));
136 | } // namespace | |
137 | ||
138 | void RegisterScalarCast(FunctionRegistry* registry) { | |
139 | DCHECK_OK(registry->AddFunction(std::make_shared<CastMetaFunction>())); | |
140 | DCHECK_OK(registry->AddFunctionOptionsType(kCastOptionsType)); | |
141 | } | |
142 | } // namespace internal | |
143 | ||
// Builds CastOptions in which every allow_* flag is the negation of `safe`:
// safe == true leaves all of them false, safe == false sets all of them true.
// to_type is not assigned here.
CastOptions::CastOptions(bool safe)
    : FunctionOptions(internal::kCastOptionsType),
      allow_int_overflow(!safe),
      allow_time_truncate(!safe),
      allow_time_overflow(!safe),
      allow_decimal_truncate(!safe),
      allow_float_truncate(!safe),
      allow_invalid_utf8(!safe) {}
152 | ||
// Out-of-class definition of the static constexpr member; required when the
// member is odr-used under language versions before C++17, where such members
// are not implicitly inline.
constexpr char CastOptions::kTypeName[];
154 | ||
// Creates a unary ScalarFunction called `name` with no documentation attached
// (doc is nullptr) and records `out_type_id` as the type id this function
// produces.
CastFunction::CastFunction(std::string name, Type::type out_type_id)
    : ScalarFunction(std::move(name), Arity::Unary(), /*doc=*/nullptr),
      out_type_id_(out_type_id) {}
158 | ||
159 | Status CastFunction::AddKernel(Type::type in_type_id, ScalarKernel kernel) { | |
160 | // We use the same KernelInit for every cast | |
161 | kernel.init = internal::CastState::Init; | |
162 | RETURN_NOT_OK(ScalarFunction::AddKernel(kernel)); | |
163 | in_type_ids_.push_back(in_type_id); | |
164 | return Status::OK(); | |
165 | } | |
166 | ||
167 | Status CastFunction::AddKernel(Type::type in_type_id, std::vector<InputType> in_types, | |
168 | OutputType out_type, ArrayKernelExec exec, | |
169 | NullHandling::type null_handling, | |
170 | MemAllocation::type mem_allocation) { | |
171 | ScalarKernel kernel; | |
172 | kernel.signature = KernelSignature::Make(std::move(in_types), std::move(out_type)); | |
173 | kernel.exec = exec; | |
174 | kernel.null_handling = null_handling; | |
175 | kernel.mem_allocation = mem_allocation; | |
176 | return AddKernel(in_type_id, std::move(kernel)); | |
177 | } | |
178 | ||
179 | Result<const Kernel*> CastFunction::DispatchExact( | |
180 | const std::vector<ValueDescr>& values) const { | |
181 | RETURN_NOT_OK(CheckArity(values)); | |
182 | ||
183 | std::vector<const ScalarKernel*> candidate_kernels; | |
184 | for (const auto& kernel : kernels_) { | |
185 | if (kernel.signature->MatchesInputs(values)) { | |
186 | candidate_kernels.push_back(&kernel); | |
187 | } | |
188 | } | |
189 | ||
190 | if (candidate_kernels.size() == 0) { | |
191 | return Status::NotImplemented("Unsupported cast from ", values[0].type->ToString(), | |
192 | " to ", ToTypeName(out_type_id_), " using function ", | |
193 | this->name()); | |
194 | } | |
195 | ||
196 | if (candidate_kernels.size() == 1) { | |
197 | // One match, return it | |
198 | return candidate_kernels[0]; | |
199 | } | |
200 | ||
201 | // Now we are in a casting scenario where we may have both a EXACT_TYPE and | |
202 | // a SAME_TYPE_ID. So we will see if there is an exact match among the | |
203 | // candidate kernels and if not we will just return the first one | |
204 | for (auto kernel : candidate_kernels) { | |
205 | const InputType& arg0 = kernel->signature->in_types()[0]; | |
206 | if (arg0.kind() == InputType::EXACT_TYPE) { | |
207 | // Bingo. Return it | |
208 | return kernel; | |
209 | } | |
210 | } | |
211 | ||
212 | // We didn't find an exact match. So just return some kernel that matches | |
213 | return candidate_kernels[0]; | |
214 | } | |
215 | ||
216 | Result<Datum> Cast(const Datum& value, const CastOptions& options, ExecContext* ctx) { | |
217 | return CallFunction("cast", {value}, &options, ctx); | |
218 | } | |
219 | ||
220 | Result<Datum> Cast(const Datum& value, std::shared_ptr<DataType> to_type, | |
221 | const CastOptions& options, ExecContext* ctx) { | |
222 | CastOptions options_with_to_type = options; | |
223 | options_with_to_type.to_type = to_type; | |
224 | return Cast(value, options_with_to_type, ctx); | |
225 | } | |
226 | ||
227 | Result<std::shared_ptr<Array>> Cast(const Array& value, std::shared_ptr<DataType> to_type, | |
228 | const CastOptions& options, ExecContext* ctx) { | |
229 | ARROW_ASSIGN_OR_RAISE(Datum result, Cast(Datum(value), to_type, options, ctx)); | |
230 | return result.make_array(); | |
231 | } | |
232 | ||
233 | Result<std::shared_ptr<CastFunction>> GetCastFunction( | |
234 | const std::shared_ptr<DataType>& to_type) { | |
235 | return internal::GetCastFunctionInternal(to_type); | |
236 | } | |
237 | ||
238 | bool CanCast(const DataType& from_type, const DataType& to_type) { | |
239 | internal::EnsureInitCastTable(); | |
240 | auto it = internal::g_cast_table.find(static_cast<int>(to_type.id())); | |
241 | if (it == internal::g_cast_table.end()) { | |
242 | return false; | |
243 | } | |
244 | ||
245 | const CastFunction* function = it->second.get(); | |
246 | DCHECK_EQ(function->out_type_id(), to_type.id()); | |
247 | ||
248 | for (auto from_id : function->in_type_ids()) { | |
249 | // XXX should probably check the output type as well | |
250 | if (from_type.id() == from_id) return true; | |
251 | } | |
252 | ||
253 | return false; | |
254 | } | |
255 | ||
256 | Result<std::vector<Datum>> Cast(std::vector<Datum> datums, std::vector<ValueDescr> descrs, | |
257 | ExecContext* ctx) { | |
258 | for (size_t i = 0; i != datums.size(); ++i) { | |
259 | if (descrs[i] != datums[i].descr()) { | |
260 | if (descrs[i].shape != datums[i].shape()) { | |
261 | return Status::NotImplemented("casting between Datum shapes"); | |
262 | } | |
263 | ||
264 | ARROW_ASSIGN_OR_RAISE(datums[i], | |
265 | Cast(datums[i], CastOptions::Safe(descrs[i].type), ctx)); | |
266 | } | |
267 | } | |
268 | ||
269 | return datums; | |
270 | } | |
271 | ||
272 | } // namespace compute | |
273 | } // namespace arrow |