]> git.proxmox.com Git - ceph.git/blobdiff - ceph/src/arrow/cpp/src/arrow/compute/kernels/row_encoder.h
import quincy 17.2.0
[ceph.git] / ceph / src / arrow / cpp / src / arrow / compute / kernels / row_encoder.h
diff --git a/ceph/src/arrow/cpp/src/arrow/compute/kernels/row_encoder.h b/ceph/src/arrow/cpp/src/arrow/compute/kernels/row_encoder.h
new file mode 100644 (file)
index 0000000..40509f2
--- /dev/null
@@ -0,0 +1,267 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <cstdint>
+
+#include "arrow/compute/exec.h"
+#include "arrow/compute/kernels/codegen_internal.h"
+#include "arrow/visitor_inline.h"
+
+namespace arrow {
+
+using internal::checked_cast;
+
+namespace compute {
+namespace internal {
+
+// Abstract interface for turning the values of one key column into per-row byte
+// strings and for turning such byte strings back into an ArrayData.
+//
+// All buffers are addressed through an array of per-row cursors
+// (`uint8_t** encoded_bytes`): entry i points into the storage of row i.
+struct KeyEncoder {
+  // the first byte of an encoded key is used to indicate nullity
+  static constexpr bool kExtraByteForNull = true;
+
+  // Values stored in that leading nullity byte.
+  static constexpr uint8_t kNullByte = 1;
+  static constexpr uint8_t kValidByte = 0;
+
+  virtual ~KeyEncoder() = default;
+
+  // Accumulates the encoded size of each row's value into lengths[i].
+  // VarLengthKeyEncoder in this header adds to (does not overwrite) lengths[i]
+  // for i in [0, batch_length); the other implementations are not visible here.
+  virtual void AddLength(const Datum&, int64_t batch_length, int32_t* lengths) = 0;
+
+  // Same as AddLength, but for a single null value and a single length slot.
+  virtual void AddLengthNull(int32_t* length) = 0;
+
+  // Encodes the datum row by row. VarLengthKeyEncoder writes row i through
+  // encoded_bytes[i] and leaves that cursor just past the bytes it wrote.
+  virtual Status Encode(const Datum&, int64_t batch_length, uint8_t** encoded_bytes) = 0;
+
+  // Encodes one null value through the single cursor *encoded_bytes.
+  virtual void EncodeNull(uint8_t** encoded_bytes) = 0;
+
+  // Decodes `length` rows, reading row i from encoded_bytes[i].
+  // VarLengthKeyEncoder advances each cursor past the bytes it consumed.
+  virtual Result<std::shared_ptr<ArrayData>> Decode(uint8_t** encoded_bytes,
+                                                    int32_t length, MemoryPool*) = 0;
+
+  // extract the null bitmap from the leading nullity bytes of encoded keys
+  // (declared only; the definition is not part of this header)
+  static Status DecodeNulls(MemoryPool* pool, int32_t length, uint8_t** encoded_bytes,
+                            std::shared_ptr<Buffer>* null_bitmap, int32_t* null_count);
+
+  // True when the first byte at encoded_bytes equals kNullByte.
+  static bool IsNull(const uint8_t* encoded_bytes) {
+    return encoded_bytes[0] == kNullByte;
+  }
+};
+
+// KeyEncoder specialization declared for boolean keys. Only the declarations
+// live in this header; every override is defined elsewhere.
+struct BooleanKeyEncoder : KeyEncoder {
+  // NOTE(review): presumably the number of payload bytes per encoded value,
+  // excluding the nullity byte -- confirm against the out-of-line definitions.
+  static constexpr int kByteWidth = 1;
+
+  void AddLength(const Datum& data, int64_t batch_length, int32_t* lengths) override;
+
+  void AddLengthNull(int32_t* length) override;
+
+  Status Encode(const Datum& data, int64_t batch_length,
+                uint8_t** encoded_bytes) override;
+
+  void EncodeNull(uint8_t** encoded_bytes) override;
+
+  Result<std::shared_ptr<ArrayData>> Decode(uint8_t** encoded_bytes, int32_t length,
+                                            MemoryPool* pool) override;
+};
+
+// KeyEncoder for types that derive from FixedWidthType. The overrides are
+// defined outside this header; only the constructor is inline.
+struct FixedWidthKeyEncoder : KeyEncoder {
+  // Takes ownership of `type` and derives the per-value byte width from it.
+  //
+  // byte_width_ is initialized from *type_ (the already-moved-into member), which
+  // is only valid because type_ is declared before byte_width_ below; keep that
+  // declaration order. The division is integral, so a bit width that is not a
+  // multiple of 8 is truncated (a bit width below 8 yields 0).
+  explicit FixedWidthKeyEncoder(std::shared_ptr<DataType> type)
+      : type_(std::move(type)),
+        byte_width_(checked_cast<const FixedWidthType&>(*type_).bit_width() / 8) {}
+
+  void AddLength(const Datum& data, int64_t batch_length, int32_t* lengths) override;
+
+  void AddLengthNull(int32_t* length) override;
+
+  Status Encode(const Datum& data, int64_t batch_length,
+                uint8_t** encoded_bytes) override;
+
+  void EncodeNull(uint8_t** encoded_bytes) override;
+
+  Result<std::shared_ptr<ArrayData>> Decode(uint8_t** encoded_bytes, int32_t length,
+                                            MemoryPool* pool) override;
+
+  // Type of the encoded column.
+  std::shared_ptr<DataType> type_;
+  // type_'s bit width divided by 8.
+  int byte_width_;
+};
+
+// FixedWidthKeyEncoder variant for dictionary keys. It overrides only Encode and
+// Decode (both defined outside this header); AddLength, AddLengthNull and
+// EncodeNull are inherited from FixedWidthKeyEncoder.
+struct DictionaryKeyEncoder : FixedWidthKeyEncoder {
+  // Forwards `type` to FixedWidthKeyEncoder and remembers `pool`.
+  DictionaryKeyEncoder(std::shared_ptr<DataType> type, MemoryPool* pool)
+      : FixedWidthKeyEncoder(std::move(type)), pool_(pool) {}
+
+  Status Encode(const Datum& data, int64_t batch_length,
+                uint8_t** encoded_bytes) override;
+
+  Result<std::shared_ptr<ArrayData>> Decode(uint8_t** encoded_bytes, int32_t length,
+                                            MemoryPool* pool) override;
+
+  // Pool supplied at construction (non-owning).
+  MemoryPool* pool_;
+  // Starts out null. NOTE(review): presumably captured from the encoded input
+  // and reused when decoding -- confirm in the out-of-line definitions.
+  std::shared_ptr<Array> dictionary_;
+};
+
+// KeyEncoder for variable-length byte-string types. T supplies the offset type
+// (T::offset_type) used both in the encoding and in the decoded array.
+//
+// Encoded layout of one value:
+//   [1 nullity byte][sizeof(Offset) bytes: payload length][payload bytes]
+// A null is written as kNullByte followed by a zero length and no payload.
+template <typename T>
+struct VarLengthKeyEncoder : KeyEncoder {
+  using Offset = typename T::offset_type;
+
+  // Adds the encoded size of every row's value to lengths[i].
+  //
+  // For an array datum the rows are the array's elements (batch_length is not
+  // consulted); for a scalar datum the same size is added to the first
+  // batch_length entries.
+  void AddLength(const Datum& data, int64_t batch_length, int32_t* lengths) override {
+    if (data.is_array()) {
+      int64_t i = 0;
+      VisitArrayDataInline<T>(
+          *data.array(),
+          [&](util::string_view bytes) {
+            lengths[i++] +=
+                kExtraByteForNull + sizeof(Offset) + static_cast<int32_t>(bytes.size());
+          },
+          // nulls still occupy the nullity byte and the (zero) length field
+          [&] { lengths[i++] += kExtraByteForNull + sizeof(Offset); });
+    } else {
+      const Scalar& scalar = *data.scalar();
+      const int32_t buffer_size =
+          scalar.is_valid ? static_cast<int32_t>(UnboxScalar<T>::Unbox(scalar).size())
+                          : 0;
+      for (int64_t i = 0; i < batch_length; i++) {
+        lengths[i] += kExtraByteForNull + sizeof(Offset) + buffer_size;
+      }
+    }
+  }
+
+  // Adds the encoded size of a single null (nullity byte + length field).
+  void AddLengthNull(int32_t* length) override {
+    *length += kExtraByteForNull + sizeof(Offset);
+  }
+
+  // Writes one encoded value per row through encoded_bytes[i] and leaves each
+  // cursor just past the bytes written. A scalar datum is repeated batch_length
+  // times. Always returns Status::OK().
+  Status Encode(const Datum& data, int64_t batch_length,
+                uint8_t** encoded_bytes) override {
+    if (data.is_array()) {
+      VisitArrayDataInline<T>(
+          *data.array(),
+          [&](util::string_view bytes) {
+            // bind by reference so the caller's cursor for this row is advanced
+            auto& encoded_ptr = *encoded_bytes++;
+            *encoded_ptr++ = kValidByte;
+            util::SafeStore(encoded_ptr, static_cast<Offset>(bytes.size()));
+            encoded_ptr += sizeof(Offset);
+            memcpy(encoded_ptr, bytes.data(), bytes.size());
+            encoded_ptr += bytes.size();
+          },
+          [&] {
+            auto& encoded_ptr = *encoded_bytes++;
+            *encoded_ptr++ = kNullByte;
+            util::SafeStore(encoded_ptr, static_cast<Offset>(0));
+            encoded_ptr += sizeof(Offset);
+          });
+    } else {
+      const auto& scalar = data.scalar_as<BaseBinaryScalar>();
+      if (scalar.is_valid) {
+        const auto& bytes = *scalar.value;
+        for (int64_t i = 0; i < batch_length; i++) {
+          auto& encoded_ptr = *encoded_bytes++;
+          *encoded_ptr++ = kValidByte;
+          util::SafeStore(encoded_ptr, static_cast<Offset>(bytes.size()));
+          encoded_ptr += sizeof(Offset);
+          memcpy(encoded_ptr, bytes.data(), bytes.size());
+          encoded_ptr += bytes.size();
+        }
+      } else {
+        for (int64_t i = 0; i < batch_length; i++) {
+          auto& encoded_ptr = *encoded_bytes++;
+          *encoded_ptr++ = kNullByte;
+          util::SafeStore(encoded_ptr, static_cast<Offset>(0));
+          encoded_ptr += sizeof(Offset);
+        }
+      }
+    }
+    return Status::OK();
+  }
+
+  // Writes a single encoded null through *encoded_bytes and advances it.
+  void EncodeNull(uint8_t** encoded_bytes) override {
+    auto& encoded_ptr = *encoded_bytes;
+    *encoded_ptr++ = kNullByte;
+    util::SafeStore(encoded_ptr, static_cast<Offset>(0));
+    encoded_ptr += sizeof(Offset);
+  }
+
+  // Decodes `length` rows into an ArrayData of type_ with buffers
+  // {null bitmap, offsets, payload}, advancing every encoded_bytes[i] past the
+  // bytes consumed. All buffers are allocated from `pool`.
+  //
+  // Returns an error if DecodeNulls or a buffer allocation fails.
+  Result<std::shared_ptr<ArrayData>> Decode(uint8_t** encoded_bytes, int32_t length,
+                                            MemoryPool* pool) override {
+    std::shared_ptr<Buffer> null_buf;
+    int32_t null_count;
+    // NOTE(review): DecodeNulls is assumed to move each cursor past its nullity
+    // byte, since the loads below read the length field at encoded_bytes[i].
+    ARROW_RETURN_NOT_OK(DecodeNulls(pool, length, encoded_bytes, &null_buf, &null_count));
+
+    // First pass: total payload size, so the payload buffer is allocated once.
+    // NOTE(review): the sum is accumulated in Offset without an overflow check.
+    Offset length_sum = 0;
+    for (int32_t i = 0; i < length; ++i) {
+      length_sum += util::SafeLoadAs<Offset>(encoded_bytes[i]);
+    }
+
+    ARROW_ASSIGN_OR_RAISE(auto offset_buf,
+                          AllocateBuffer(sizeof(Offset) * (1 + length), pool));
+    // Allocate from the caller's pool as well; previously this buffer was taken
+    // from the default pool, bypassing the pool the caller asked for.
+    ARROW_ASSIGN_OR_RAISE(auto key_buf, AllocateBuffer(length_sum, pool));
+
+    auto raw_offsets = reinterpret_cast<Offset*>(offset_buf->mutable_data());
+    auto raw_keys = key_buf->mutable_data();
+
+    // Second pass: copy payloads back to back and record their start offsets.
+    Offset current_offset = 0;
+    for (int32_t i = 0; i < length; ++i) {
+      raw_offsets[i] = current_offset;
+
+      auto key_length = util::SafeLoadAs<Offset>(encoded_bytes[i]);
+      encoded_bytes[i] += sizeof(Offset);
+
+      memcpy(raw_keys + current_offset, encoded_bytes[i], key_length);
+      encoded_bytes[i] += key_length;
+
+      current_offset += key_length;
+    }
+    // closing offset: one past the last payload byte
+    raw_offsets[length] = current_offset;
+
+    return ArrayData::Make(
+        type_, length, {std::move(null_buf), std::move(offset_buf), std::move(key_buf)},
+        null_count);
+  }
+
+  explicit VarLengthKeyEncoder(std::shared_ptr<DataType> type) : type_(std::move(type)) {}
+
+  // Type given to the ArrayData produced by Decode.
+  std::shared_ptr<DataType> type_;
+};
+
+// Encodes the rows of ExecBatches into byte strings stored back to back in a
+// single byte vector, and decodes selected rows back into an ExecBatch.
+//
+// Init, Clear, EncodeAndAppend and Decode are defined outside this header.
+class ARROW_EXPORT RowEncoder {
+ public:
+  // Row id that encoded_row() maps to the stored all-null encoding.
+  static constexpr int kRowIdForNulls() { return -1; }
+
+  void Init(const std::vector<ValueDescr>& column_types, ExecContext* ctx);
+  void Clear();
+  Status EncodeAndAppend(const ExecBatch& batch);
+  Result<ExecBatch> Decode(int64_t num_rows, const int32_t* row_ids);
+
+  // Returns a copy of the encoded bytes of row i, or of encoded_nulls_ when
+  // i == kRowIdForNulls().
+  //
+  // No bounds checking is done: any other i must satisfy 0 <= i < num_rows().
+  inline std::string encoded_row(int32_t i) const {
+    if (i == kRowIdForNulls()) {
+      return std::string(reinterpret_cast<const char*>(encoded_nulls_.data()),
+                         encoded_nulls_.size());
+    }
+    // row i occupies bytes_[offsets_[i], offsets_[i + 1])
+    int32_t row_length = offsets_[i + 1] - offsets_[i];
+    return std::string(reinterpret_cast<const char*>(bytes_.data() + offsets_[i]),
+                       row_length);
+  }
+
+  // Number of encoded rows: offsets_ holds one entry more than there are rows,
+  // and an empty offsets_ means no rows.
+  int32_t num_rows() const {
+    return offsets_.size() == 0 ? 0 : static_cast<int32_t>(offsets_.size() - 1);
+  }
+
+ private:
+  // Null until Init() is called, so the pointer is never left indeterminate.
+  ExecContext* ctx_ = nullptr;
+  std::vector<std::shared_ptr<KeyEncoder>> encoders_;
+  // Row boundaries into bytes_ (see encoded_row).
+  std::vector<int32_t> offsets_;
+  // Encoded rows, stored back to back.
+  std::vector<uint8_t> bytes_;
+  // Bytes returned by encoded_row(kRowIdForNulls()).
+  std::vector<uint8_t> encoded_nulls_;
+};
+
+}  // namespace internal
+}  // namespace compute
+}  // namespace arrow