]> git.proxmox.com Git - ceph.git/blobdiff - ceph/src/arrow/cpp/src/gandiva/expression_registry.cc
import quincy 17.2.0
[ceph.git] / ceph / src / arrow / cpp / src / gandiva / expression_registry.cc
diff --git a/ceph/src/arrow/cpp/src/gandiva/expression_registry.cc b/ceph/src/arrow/cpp/src/gandiva/expression_registry.cc
new file mode 100644 (file)
index 0000000..c3a08fd
--- /dev/null
@@ -0,0 +1,187 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "gandiva/expression_registry.h"
+
+#include "gandiva/function_registry.h"
+#include "gandiva/llvm_types.h"
+
+namespace gandiva {
+
+ExpressionRegistry::ExpressionRegistry() {
+  function_registry_.reset(new FunctionRegistry());
+}
+
+ExpressionRegistry::~ExpressionRegistry() {}
+
+// to be used only to create function_signature_start
+ExpressionRegistry::FunctionSignatureIterator::FunctionSignatureIterator(
+    native_func_iterator_type nf_it, native_func_iterator_type nf_it_end)
+    : native_func_it_{nf_it},
+      native_func_it_end_{nf_it_end},
+      func_sig_it_{&(nf_it->signatures().front())} {}
+
+// to be used only to create function_signature_end
+ExpressionRegistry::FunctionSignatureIterator::FunctionSignatureIterator(
+    func_sig_iterator_type fs_it)
+    : native_func_it_{nullptr}, native_func_it_end_{nullptr}, func_sig_it_{fs_it} {}
+
+const ExpressionRegistry::FunctionSignatureIterator
+ExpressionRegistry::function_signature_begin() {
+  return FunctionSignatureIterator(function_registry_->begin(),
+                                   function_registry_->end());
+}
+
+const ExpressionRegistry::FunctionSignatureIterator
+ExpressionRegistry::function_signature_end() const {
+  return FunctionSignatureIterator(&(*(function_registry_->back()->signatures().end())));
+}
+
+bool ExpressionRegistry::FunctionSignatureIterator::operator!=(
+    const FunctionSignatureIterator& func_sign_it) {
+  return func_sign_it.func_sig_it_ != this->func_sig_it_;
+}
+
+FunctionSignature ExpressionRegistry::FunctionSignatureIterator::operator*() {
+  return *func_sig_it_;
+}
+
+ExpressionRegistry::func_sig_iterator_type ExpressionRegistry::FunctionSignatureIterator::
+operator++(int increment) {
+  ++func_sig_it_;
+  // point func_sig_it_ to first signature of next nativefunction if func_sig_it_ is
+  // pointing to end
+  if (func_sig_it_ == &(*native_func_it_->signatures().end())) {
+    ++native_func_it_;
+    if (native_func_it_ == native_func_it_end_) {  // last native function
+      return func_sig_it_;
+    }
+    func_sig_it_ = &(native_func_it_->signatures().front());
+  }
+  return func_sig_it_;
+}
+
+static void AddArrowTypesToVector(arrow::Type::type type, DataTypeVector& vector);
+
+static DataTypeVector InitSupportedTypes() {
+  DataTypeVector data_type_vector;
+  llvm::LLVMContext llvm_context;
+  LLVMTypes llvm_types(llvm_context);
+  auto supported_arrow_types = llvm_types.GetSupportedArrowTypes();
+  for (auto& type_id : supported_arrow_types) {
+    AddArrowTypesToVector(type_id, data_type_vector);
+  }
+  return data_type_vector;
+}
+
+DataTypeVector ExpressionRegistry::supported_types_ = InitSupportedTypes();
+
+static void AddArrowTypesToVector(arrow::Type::type type, DataTypeVector& vector) {
+  switch (type) {
+    case arrow::Type::type::BOOL:
+      vector.push_back(arrow::boolean());
+      break;
+    case arrow::Type::type::UINT8:
+      vector.push_back(arrow::uint8());
+      break;
+    case arrow::Type::type::INT8:
+      vector.push_back(arrow::int8());
+      break;
+    case arrow::Type::type::UINT16:
+      vector.push_back(arrow::uint16());
+      break;
+    case arrow::Type::type::INT16:
+      vector.push_back(arrow::int16());
+      break;
+    case arrow::Type::type::UINT32:
+      vector.push_back(arrow::uint32());
+      break;
+    case arrow::Type::type::INT32:
+      vector.push_back(arrow::int32());
+      break;
+    case arrow::Type::type::UINT64:
+      vector.push_back(arrow::uint64());
+      break;
+    case arrow::Type::type::INT64:
+      vector.push_back(arrow::int64());
+      break;
+    case arrow::Type::type::HALF_FLOAT:
+      vector.push_back(arrow::float16());
+      break;
+    case arrow::Type::type::FLOAT:
+      vector.push_back(arrow::float32());
+      break;
+    case arrow::Type::type::DOUBLE:
+      vector.push_back(arrow::float64());
+      break;
+    case arrow::Type::type::STRING:
+      vector.push_back(arrow::utf8());
+      break;
+    case arrow::Type::type::BINARY:
+      vector.push_back(arrow::binary());
+      break;
+    case arrow::Type::type::DATE32:
+      vector.push_back(arrow::date32());
+      break;
+    case arrow::Type::type::DATE64:
+      vector.push_back(arrow::date64());
+      break;
+    case arrow::Type::type::TIMESTAMP:
+      vector.push_back(arrow::timestamp(arrow::TimeUnit::SECOND));
+      vector.push_back(arrow::timestamp(arrow::TimeUnit::MILLI));
+      vector.push_back(arrow::timestamp(arrow::TimeUnit::NANO));
+      vector.push_back(arrow::timestamp(arrow::TimeUnit::MICRO));
+      break;
+    case arrow::Type::type::TIME32:
+      vector.push_back(arrow::time32(arrow::TimeUnit::SECOND));
+      vector.push_back(arrow::time32(arrow::TimeUnit::MILLI));
+      break;
+    case arrow::Type::type::TIME64:
+      vector.push_back(arrow::time64(arrow::TimeUnit::MICRO));
+      vector.push_back(arrow::time64(arrow::TimeUnit::NANO));
+      break;
+    case arrow::Type::type::NA:
+      vector.push_back(arrow::null());
+      break;
+    case arrow::Type::type::DECIMAL:
+      vector.push_back(arrow::decimal(38, 0));
+      break;
+    case arrow::Type::type::INTERVAL_MONTHS:
+      vector.push_back(arrow::month_interval());
+      break;
+    case arrow::Type::type::INTERVAL_DAY_TIME:
+      vector.push_back(arrow::day_time_interval());
+      break;
+    default:
+      // Unsupported types. test ensures that
+      // when one of these are added build breaks.
+      DCHECK(false);
+  }
+}
+
+std::vector<std::shared_ptr<FunctionSignature>> GetRegisteredFunctionSignatures() {
+  ExpressionRegistry registry;
+  std::vector<std::shared_ptr<FunctionSignature>> signatures;
+  for (auto iter = registry.function_signature_begin();
+       iter != registry.function_signature_end(); iter++) {
+    signatures.push_back(std::make_shared<FunctionSignature>(
+        (*iter).base_name(), (*iter).param_types(), (*iter).ret_type()));
+  }
+  return signatures;
+}
+
+}  // namespace gandiva