--- /dev/null
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <cstdint>
+
+#include "arrow/compute/exec.h"
+#include "arrow/compute/kernels/codegen_internal.h"
+#include "arrow/visitor_inline.h"
+
+namespace arrow {
+
+using internal::checked_cast;
+
+namespace compute {
+namespace internal {
+
+struct KeyEncoder {
+ // the first byte of an encoded key is used to indicate nullity
+ static constexpr bool kExtraByteForNull = true;
+
+ static constexpr uint8_t kNullByte = 1;
+ static constexpr uint8_t kValidByte = 0;
+
+ virtual ~KeyEncoder() = default;
+
+ virtual void AddLength(const Datum&, int64_t batch_length, int32_t* lengths) = 0;
+
+ virtual void AddLengthNull(int32_t* length) = 0;
+
+ virtual Status Encode(const Datum&, int64_t batch_length, uint8_t** encoded_bytes) = 0;
+
+ virtual void EncodeNull(uint8_t** encoded_bytes) = 0;
+
+ virtual Result<std::shared_ptr<ArrayData>> Decode(uint8_t** encoded_bytes,
+ int32_t length, MemoryPool*) = 0;
+
+ // extract the null bitmap from the leading nullity bytes of encoded keys
+ static Status DecodeNulls(MemoryPool* pool, int32_t length, uint8_t** encoded_bytes,
+ std::shared_ptr<Buffer>* null_bitmap, int32_t* null_count);
+
+ static bool IsNull(const uint8_t* encoded_bytes) {
+ return encoded_bytes[0] == kNullByte;
+ }
+};
+
+struct BooleanKeyEncoder : KeyEncoder {
+ static constexpr int kByteWidth = 1;
+
+ void AddLength(const Datum& data, int64_t batch_length, int32_t* lengths) override;
+
+ void AddLengthNull(int32_t* length) override;
+
+ Status Encode(const Datum& data, int64_t batch_length,
+ uint8_t** encoded_bytes) override;
+
+ void EncodeNull(uint8_t** encoded_bytes) override;
+
+ Result<std::shared_ptr<ArrayData>> Decode(uint8_t** encoded_bytes, int32_t length,
+ MemoryPool* pool) override;
+};
+
+struct FixedWidthKeyEncoder : KeyEncoder {
+ explicit FixedWidthKeyEncoder(std::shared_ptr<DataType> type)
+ : type_(std::move(type)),
+ byte_width_(checked_cast<const FixedWidthType&>(*type_).bit_width() / 8) {}
+
+ void AddLength(const Datum& data, int64_t batch_length, int32_t* lengths) override;
+
+ void AddLengthNull(int32_t* length) override;
+
+ Status Encode(const Datum& data, int64_t batch_length,
+ uint8_t** encoded_bytes) override;
+
+ void EncodeNull(uint8_t** encoded_bytes) override;
+
+ Result<std::shared_ptr<ArrayData>> Decode(uint8_t** encoded_bytes, int32_t length,
+ MemoryPool* pool) override;
+
+ std::shared_ptr<DataType> type_;
+ int byte_width_;
+};
+
+struct DictionaryKeyEncoder : FixedWidthKeyEncoder {
+ DictionaryKeyEncoder(std::shared_ptr<DataType> type, MemoryPool* pool)
+ : FixedWidthKeyEncoder(std::move(type)), pool_(pool) {}
+
+ Status Encode(const Datum& data, int64_t batch_length,
+ uint8_t** encoded_bytes) override;
+
+ Result<std::shared_ptr<ArrayData>> Decode(uint8_t** encoded_bytes, int32_t length,
+ MemoryPool* pool) override;
+
+ MemoryPool* pool_;
+ std::shared_ptr<Array> dictionary_;
+};
+
+template <typename T>
+struct VarLengthKeyEncoder : KeyEncoder {
+ using Offset = typename T::offset_type;
+
+ void AddLength(const Datum& data, int64_t batch_length, int32_t* lengths) override {
+ if (data.is_array()) {
+ int64_t i = 0;
+ VisitArrayDataInline<T>(
+ *data.array(),
+ [&](util::string_view bytes) {
+ lengths[i++] +=
+ kExtraByteForNull + sizeof(Offset) + static_cast<int32_t>(bytes.size());
+ },
+ [&] { lengths[i++] += kExtraByteForNull + sizeof(Offset); });
+ } else {
+ const Scalar& scalar = *data.scalar();
+ const int32_t buffer_size =
+ scalar.is_valid ? static_cast<int32_t>(UnboxScalar<T>::Unbox(scalar).size())
+ : 0;
+ for (int64_t i = 0; i < batch_length; i++) {
+ lengths[i] += kExtraByteForNull + sizeof(Offset) + buffer_size;
+ }
+ }
+ }
+
+ void AddLengthNull(int32_t* length) override {
+ *length += kExtraByteForNull + sizeof(Offset);
+ }
+
+ Status Encode(const Datum& data, int64_t batch_length,
+ uint8_t** encoded_bytes) override {
+ if (data.is_array()) {
+ VisitArrayDataInline<T>(
+ *data.array(),
+ [&](util::string_view bytes) {
+ auto& encoded_ptr = *encoded_bytes++;
+ *encoded_ptr++ = kValidByte;
+ util::SafeStore(encoded_ptr, static_cast<Offset>(bytes.size()));
+ encoded_ptr += sizeof(Offset);
+ memcpy(encoded_ptr, bytes.data(), bytes.size());
+ encoded_ptr += bytes.size();
+ },
+ [&] {
+ auto& encoded_ptr = *encoded_bytes++;
+ *encoded_ptr++ = kNullByte;
+ util::SafeStore(encoded_ptr, static_cast<Offset>(0));
+ encoded_ptr += sizeof(Offset);
+ });
+ } else {
+ const auto& scalar = data.scalar_as<BaseBinaryScalar>();
+ if (scalar.is_valid) {
+ const auto& bytes = *scalar.value;
+ for (int64_t i = 0; i < batch_length; i++) {
+ auto& encoded_ptr = *encoded_bytes++;
+ *encoded_ptr++ = kValidByte;
+ util::SafeStore(encoded_ptr, static_cast<Offset>(bytes.size()));
+ encoded_ptr += sizeof(Offset);
+ memcpy(encoded_ptr, bytes.data(), bytes.size());
+ encoded_ptr += bytes.size();
+ }
+ } else {
+ for (int64_t i = 0; i < batch_length; i++) {
+ auto& encoded_ptr = *encoded_bytes++;
+ *encoded_ptr++ = kNullByte;
+ util::SafeStore(encoded_ptr, static_cast<Offset>(0));
+ encoded_ptr += sizeof(Offset);
+ }
+ }
+ }
+ return Status::OK();
+ }
+
+ void EncodeNull(uint8_t** encoded_bytes) override {
+ auto& encoded_ptr = *encoded_bytes;
+ *encoded_ptr++ = kNullByte;
+ util::SafeStore(encoded_ptr, static_cast<Offset>(0));
+ encoded_ptr += sizeof(Offset);
+ }
+
+ Result<std::shared_ptr<ArrayData>> Decode(uint8_t** encoded_bytes, int32_t length,
+ MemoryPool* pool) override {
+ std::shared_ptr<Buffer> null_buf;
+ int32_t null_count;
+ ARROW_RETURN_NOT_OK(DecodeNulls(pool, length, encoded_bytes, &null_buf, &null_count));
+
+ Offset length_sum = 0;
+ for (int32_t i = 0; i < length; ++i) {
+ length_sum += util::SafeLoadAs<Offset>(encoded_bytes[i]);
+ }
+
+ ARROW_ASSIGN_OR_RAISE(auto offset_buf,
+ AllocateBuffer(sizeof(Offset) * (1 + length), pool));
+ ARROW_ASSIGN_OR_RAISE(auto key_buf, AllocateBuffer(length_sum));
+
+ auto raw_offsets = reinterpret_cast<Offset*>(offset_buf->mutable_data());
+ auto raw_keys = key_buf->mutable_data();
+
+ Offset current_offset = 0;
+ for (int32_t i = 0; i < length; ++i) {
+ raw_offsets[i] = current_offset;
+
+ auto key_length = util::SafeLoadAs<Offset>(encoded_bytes[i]);
+ encoded_bytes[i] += sizeof(Offset);
+
+ memcpy(raw_keys + current_offset, encoded_bytes[i], key_length);
+ encoded_bytes[i] += key_length;
+
+ current_offset += key_length;
+ }
+ raw_offsets[length] = current_offset;
+
+ return ArrayData::Make(
+ type_, length, {std::move(null_buf), std::move(offset_buf), std::move(key_buf)},
+ null_count);
+ }
+
+ explicit VarLengthKeyEncoder(std::shared_ptr<DataType> type) : type_(std::move(type)) {}
+
+ std::shared_ptr<DataType> type_;
+};
+
+class ARROW_EXPORT RowEncoder {
+ public:
+ static constexpr int kRowIdForNulls() { return -1; }
+
+ void Init(const std::vector<ValueDescr>& column_types, ExecContext* ctx);
+ void Clear();
+ Status EncodeAndAppend(const ExecBatch& batch);
+ Result<ExecBatch> Decode(int64_t num_rows, const int32_t* row_ids);
+
+ inline std::string encoded_row(int32_t i) const {
+ if (i == kRowIdForNulls()) {
+ return std::string(reinterpret_cast<const char*>(encoded_nulls_.data()),
+ encoded_nulls_.size());
+ }
+ int32_t row_length = offsets_[i + 1] - offsets_[i];
+ return std::string(reinterpret_cast<const char*>(bytes_.data() + offsets_[i]),
+ row_length);
+ }
+
+ int32_t num_rows() const {
+ return offsets_.size() == 0 ? 0 : static_cast<int32_t>(offsets_.size() - 1);
+ }
+
+ private:
+ ExecContext* ctx_;
+ std::vector<std::shared_ptr<KeyEncoder>> encoders_;
+ std::vector<int32_t> offsets_;
+ std::vector<uint8_t> bytes_;
+ std::vector<uint8_t> encoded_nulls_;
+};
+
+} // namespace internal
+} // namespace compute
+} // namespace arrow