]> git.proxmox.com Git - ceph.git/blobdiff - ceph/src/arrow/cpp/src/arrow/compute/kernel.cc
import quincy 17.2.0
[ceph.git] / ceph / src / arrow / cpp / src / arrow / compute / kernel.cc
diff --git a/ceph/src/arrow/cpp/src/arrow/compute/kernel.cc b/ceph/src/arrow/cpp/src/arrow/compute/kernel.cc
new file mode 100644 (file)
index 0000000..666b73e
--- /dev/null
@@ -0,0 +1,507 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/compute/kernel.h"
+
+#include <cstddef>
+#include <memory>
+#include <sstream>
+#include <string>
+
+#include "arrow/buffer.h"
+#include "arrow/compute/exec.h"
+#include "arrow/result.h"
+#include "arrow/type_traits.h"
+#include "arrow/util/bit_util.h"
+#include "arrow/util/checked_cast.h"
+#include "arrow/util/hash_util.h"
+#include "arrow/util/logging.h"
+#include "arrow/util/macros.h"
+
+namespace arrow {
+
+using internal::checked_cast;
+using internal::hash_combine;
+
+static constexpr size_t kHashSeed = 0;
+
+namespace compute {
+
+// ----------------------------------------------------------------------
+// KernelContext
+
+Result<std::shared_ptr<ResizableBuffer>> KernelContext::Allocate(int64_t nbytes) {
+  return AllocateResizableBuffer(nbytes, exec_ctx_->memory_pool());
+}
+
+Result<std::shared_ptr<ResizableBuffer>> KernelContext::AllocateBitmap(int64_t num_bits) {
+  const int64_t nbytes = BitUtil::BytesForBits(num_bits);
+  ARROW_ASSIGN_OR_RAISE(std::shared_ptr<ResizableBuffer> result,
+                        AllocateResizableBuffer(nbytes, exec_ctx_->memory_pool()));
+  // Since bitmaps are typically written bit by bit, we could leak uninitialized bits.
+  // Make sure all memory is initialized (this also appeases Valgrind).
+  std::memset(result->mutable_data(), 0, result->size());
+  return result;
+}
+
+Status Kernel::InitAll(KernelContext* ctx, const KernelInitArgs& args,
+                       std::vector<std::unique_ptr<KernelState>>* states) {
+  for (auto& state : *states) {
+    ARROW_ASSIGN_OR_RAISE(state, args.kernel->init(ctx, args));
+  }
+  return Status::OK();
+}
+
+Result<std::unique_ptr<KernelState>> ScalarAggregateKernel::MergeAll(
+    const ScalarAggregateKernel* kernel, KernelContext* ctx,
+    std::vector<std::unique_ptr<KernelState>> states) {
+  auto out = std::move(states.back());
+  states.pop_back();
+  ctx->SetState(out.get());
+  for (auto& state : states) {
+    RETURN_NOT_OK(kernel->merge(ctx, std::move(*state), out.get()));
+  }
+  return std::move(out);
+}
+
+// ----------------------------------------------------------------------
+// Some basic TypeMatcher implementations
+
+namespace match {
+
+class SameTypeIdMatcher : public TypeMatcher {
+ public:
+  explicit SameTypeIdMatcher(Type::type accepted_id) : accepted_id_(accepted_id) {}
+
+  bool Matches(const DataType& type) const override { return type.id() == accepted_id_; }
+
+  std::string ToString() const override {
+    std::stringstream ss;
+    ss << "Type::" << ::arrow::internal::ToString(accepted_id_);
+    return ss.str();
+  }
+
+  bool Equals(const TypeMatcher& other) const override {
+    if (this == &other) {
+      return true;
+    }
+    auto casted = dynamic_cast<const SameTypeIdMatcher*>(&other);
+    if (casted == nullptr) {
+      return false;
+    }
+    return this->accepted_id_ == casted->accepted_id_;
+  }
+
+ private:
+  Type::type accepted_id_;
+};
+
+std::shared_ptr<TypeMatcher> SameTypeId(Type::type type_id) {
+  return std::make_shared<SameTypeIdMatcher>(type_id);
+}
+
+template <typename ArrowType>
+class TimeUnitMatcher : public TypeMatcher {
+  using ThisType = TimeUnitMatcher<ArrowType>;
+
+ public:
+  explicit TimeUnitMatcher(TimeUnit::type accepted_unit)
+      : accepted_unit_(accepted_unit) {}
+
+  bool Matches(const DataType& type) const override {
+    if (type.id() != ArrowType::type_id) {
+      return false;
+    }
+    const auto& time_type = checked_cast<const ArrowType&>(type);
+    return time_type.unit() == accepted_unit_;
+  }
+
+  bool Equals(const TypeMatcher& other) const override {
+    if (this == &other) {
+      return true;
+    }
+    auto casted = dynamic_cast<const ThisType*>(&other);
+    if (casted == nullptr) {
+      return false;
+    }
+    return this->accepted_unit_ == casted->accepted_unit_;
+  }
+
+  std::string ToString() const override {
+    std::stringstream ss;
+    ss << ArrowType::type_name() << "(" << ::arrow::internal::ToString(accepted_unit_)
+       << ")";
+    return ss.str();
+  }
+
+ private:
+  TimeUnit::type accepted_unit_;
+};
+
+using DurationTypeUnitMatcher = TimeUnitMatcher<DurationType>;
+using Time32TypeUnitMatcher = TimeUnitMatcher<Time32Type>;
+using Time64TypeUnitMatcher = TimeUnitMatcher<Time64Type>;
+using TimestampTypeUnitMatcher = TimeUnitMatcher<TimestampType>;
+
+std::shared_ptr<TypeMatcher> TimestampTypeUnit(TimeUnit::type unit) {
+  return std::make_shared<TimestampTypeUnitMatcher>(unit);
+}
+
+std::shared_ptr<TypeMatcher> Time32TypeUnit(TimeUnit::type unit) {
+  return std::make_shared<Time32TypeUnitMatcher>(unit);
+}
+
+std::shared_ptr<TypeMatcher> Time64TypeUnit(TimeUnit::type unit) {
+  return std::make_shared<Time64TypeUnitMatcher>(unit);
+}
+
+std::shared_ptr<TypeMatcher> DurationTypeUnit(TimeUnit::type unit) {
+  return std::make_shared<DurationTypeUnitMatcher>(unit);
+}
+
+class IntegerMatcher : public TypeMatcher {
+ public:
+  IntegerMatcher() {}
+
+  bool Matches(const DataType& type) const override { return is_integer(type.id()); }
+
+  bool Equals(const TypeMatcher& other) const override {
+    if (this == &other) {
+      return true;
+    }
+    auto casted = dynamic_cast<const IntegerMatcher*>(&other);
+    return casted != nullptr;
+  }
+
+  std::string ToString() const override { return "integer"; }
+};
+
+std::shared_ptr<TypeMatcher> Integer() { return std::make_shared<IntegerMatcher>(); }
+
+class PrimitiveMatcher : public TypeMatcher {
+ public:
+  PrimitiveMatcher() {}
+
+  bool Matches(const DataType& type) const override { return is_primitive(type.id()); }
+
+  bool Equals(const TypeMatcher& other) const override {
+    if (this == &other) {
+      return true;
+    }
+    auto casted = dynamic_cast<const PrimitiveMatcher*>(&other);
+    return casted != nullptr;
+  }
+
+  std::string ToString() const override { return "primitive"; }
+};
+
+std::shared_ptr<TypeMatcher> Primitive() { return std::make_shared<PrimitiveMatcher>(); }
+
+class BinaryLikeMatcher : public TypeMatcher {
+ public:
+  BinaryLikeMatcher() {}
+
+  bool Matches(const DataType& type) const override { return is_binary_like(type.id()); }
+
+  bool Equals(const TypeMatcher& other) const override {
+    if (this == &other) {
+      return true;
+    }
+    auto casted = dynamic_cast<const BinaryLikeMatcher*>(&other);
+    return casted != nullptr;
+  }
+  std::string ToString() const override { return "binary-like"; }
+};
+
+std::shared_ptr<TypeMatcher> BinaryLike() {
+  return std::make_shared<BinaryLikeMatcher>();
+}
+
+class LargeBinaryLikeMatcher : public TypeMatcher {
+ public:
+  LargeBinaryLikeMatcher() {}
+
+  bool Matches(const DataType& type) const override {
+    return is_large_binary_like(type.id());
+  }
+
+  bool Equals(const TypeMatcher& other) const override {
+    if (this == &other) {
+      return true;
+    }
+    auto casted = dynamic_cast<const LargeBinaryLikeMatcher*>(&other);
+    return casted != nullptr;
+  }
+  std::string ToString() const override { return "large-binary-like"; }
+};
+
+class FixedSizeBinaryLikeMatcher : public TypeMatcher {
+ public:
+  FixedSizeBinaryLikeMatcher() {}
+
+  bool Matches(const DataType& type) const override {
+    return is_fixed_size_binary(type.id());
+  }
+
+  bool Equals(const TypeMatcher& other) const override {
+    if (this == &other) {
+      return true;
+    }
+    auto casted = dynamic_cast<const FixedSizeBinaryLikeMatcher*>(&other);
+    return casted != nullptr;
+  }
+  std::string ToString() const override { return "fixed-size-binary-like"; }
+};
+
+std::shared_ptr<TypeMatcher> LargeBinaryLike() {
+  return std::make_shared<LargeBinaryLikeMatcher>();
+}
+
+std::shared_ptr<TypeMatcher> FixedSizeBinaryLike() {
+  return std::make_shared<FixedSizeBinaryLikeMatcher>();
+}
+
+}  // namespace match
+
+// ----------------------------------------------------------------------
+// InputType
+
+size_t InputType::Hash() const {
+  size_t result = kHashSeed;
+  hash_combine(result, static_cast<int>(shape_));
+  hash_combine(result, static_cast<int>(kind_));
+  switch (kind_) {
+    case InputType::EXACT_TYPE:
+      hash_combine(result, type_->Hash());
+      break;
+    default:
+      break;
+  }
+  return result;
+}
+
+std::string InputType::ToString() const {
+  std::stringstream ss;
+  switch (shape_) {
+    case ValueDescr::ANY:
+      ss << "any";
+      break;
+    case ValueDescr::ARRAY:
+      ss << "array";
+      break;
+    case ValueDescr::SCALAR:
+      ss << "scalar";
+      break;
+    default:
+      DCHECK(false);
+      break;
+  }
+  ss << "[";
+  switch (kind_) {
+    case InputType::ANY_TYPE:
+      ss << "any";
+      break;
+    case InputType::EXACT_TYPE:
+      ss << type_->ToString();
+      break;
+    case InputType::USE_TYPE_MATCHER: {
+      ss << type_matcher_->ToString();
+    } break;
+    default:
+      DCHECK(false);
+      break;
+  }
+  ss << "]";
+  return ss.str();
+}
+
+bool InputType::Equals(const InputType& other) const {
+  if (this == &other) {
+    return true;
+  }
+  if (kind_ != other.kind_ || shape_ != other.shape_) {
+    return false;
+  }
+  switch (kind_) {
+    case InputType::ANY_TYPE:
+      return true;
+    case InputType::EXACT_TYPE:
+      return type_->Equals(*other.type_);
+    case InputType::USE_TYPE_MATCHER:
+      return type_matcher_->Equals(*other.type_matcher_);
+    default:
+      return false;
+  }
+}
+
+bool InputType::Matches(const ValueDescr& descr) const {
+  if (shape_ != ValueDescr::ANY && descr.shape != shape_) {
+    return false;
+  }
+  switch (kind_) {
+    case InputType::EXACT_TYPE:
+      return type_->Equals(*descr.type);
+    case InputType::USE_TYPE_MATCHER:
+      return type_matcher_->Matches(*descr.type);
+    default:
+      // ANY_TYPE
+      return true;
+  }
+}
+
+bool InputType::Matches(const Datum& value) const { return Matches(value.descr()); }
+
+const std::shared_ptr<DataType>& InputType::type() const {
+  DCHECK_EQ(InputType::EXACT_TYPE, kind_);
+  return type_;
+}
+
+const TypeMatcher& InputType::type_matcher() const {
+  DCHECK_EQ(InputType::USE_TYPE_MATCHER, kind_);
+  return *type_matcher_;
+}
+
+// ----------------------------------------------------------------------
+// OutputType
+
+OutputType::OutputType(ValueDescr descr) : OutputType(descr.type) {
+  shape_ = descr.shape;
+}
+
+Result<ValueDescr> OutputType::Resolve(KernelContext* ctx,
+                                       const std::vector<ValueDescr>& args) const {
+  ValueDescr::Shape broadcasted_shape = GetBroadcastShape(args);
+  if (kind_ == OutputType::FIXED) {
+    return ValueDescr(type_, shape_ == ValueDescr::ANY ? broadcasted_shape : shape_);
+  } else {
+    ARROW_ASSIGN_OR_RAISE(ValueDescr resolved_descr, resolver_(ctx, args));
+    if (resolved_descr.shape == ValueDescr::ANY) {
+      resolved_descr.shape = broadcasted_shape;
+    }
+    return resolved_descr;
+  }
+}
+
+const std::shared_ptr<DataType>& OutputType::type() const {
+  DCHECK_EQ(FIXED, kind_);
+  return type_;
+}
+
+const OutputType::Resolver& OutputType::resolver() const {
+  DCHECK_EQ(COMPUTED, kind_);
+  return resolver_;
+}
+
+std::string OutputType::ToString() const {
+  if (kind_ == OutputType::FIXED) {
+    return type_->ToString();
+  } else {
+    return "computed";
+  }
+}
+
+// ----------------------------------------------------------------------
+// KernelSignature
+
+KernelSignature::KernelSignature(std::vector<InputType> in_types, OutputType out_type,
+                                 bool is_varargs)
+    : in_types_(std::move(in_types)),
+      out_type_(std::move(out_type)),
+      is_varargs_(is_varargs),
+      hash_code_(0) {
+  DCHECK(!is_varargs || (is_varargs && (in_types_.size() >= 1)));
+}
+
+std::shared_ptr<KernelSignature> KernelSignature::Make(std::vector<InputType> in_types,
+                                                       OutputType out_type,
+                                                       bool is_varargs) {
+  return std::make_shared<KernelSignature>(std::move(in_types), std::move(out_type),
+                                           is_varargs);
+}
+
+bool KernelSignature::Equals(const KernelSignature& other) const {
+  if (is_varargs_ != other.is_varargs_) {
+    return false;
+  }
+  if (in_types_.size() != other.in_types_.size()) {
+    return false;
+  }
+  for (size_t i = 0; i < in_types_.size(); ++i) {
+    if (!in_types_[i].Equals(other.in_types_[i])) {
+      return false;
+    }
+  }
+  return true;
+}
+
+bool KernelSignature::MatchesInputs(const std::vector<ValueDescr>& args) const {
+  if (is_varargs_) {
+    for (size_t i = 0; i < args.size(); ++i) {
+      if (!in_types_[std::min(i, in_types_.size() - 1)].Matches(args[i])) {
+        return false;
+      }
+    }
+  } else {
+    if (args.size() != in_types_.size()) {
+      return false;
+    }
+    for (size_t i = 0; i < in_types_.size(); ++i) {
+      if (!in_types_[i].Matches(args[i])) {
+        return false;
+      }
+    }
+  }
+  return true;
+}
+
+size_t KernelSignature::Hash() const {
+  if (hash_code_ != 0) {
+    return hash_code_;
+  }
+  size_t result = kHashSeed;
+  for (const auto& in_type : in_types_) {
+    hash_combine(result, in_type.Hash());
+  }
+  hash_code_ = result;
+  return result;
+}
+
+std::string KernelSignature::ToString() const {
+  std::stringstream ss;
+
+  if (is_varargs_) {
+    ss << "varargs[";
+  } else {
+    ss << "(";
+  }
+  for (size_t i = 0; i < in_types_.size(); ++i) {
+    if (i > 0) {
+      ss << ", ";
+    }
+    ss << in_types_[i].ToString();
+  }
+  if (is_varargs_) {
+    ss << "]";
+  } else {
+    ss << ")";
+  }
+  ss << " -> " << out_type_.ToString();
+  return ss.str();
+}
+
+}  // namespace compute
+}  // namespace arrow