]> git.proxmox.com Git - ceph.git/blob - ceph/src/arrow/cpp/src/arrow/compute/kernels/row_encoder.h
import quincy 17.2.0
[ceph.git] / ceph / src / arrow / cpp / src / arrow / compute / kernels / row_encoder.h
1 // Licensed to the Apache Software Foundation (ASF) under one
2 // or more contributor license agreements. See the NOTICE file
3 // distributed with this work for additional information
4 // regarding copyright ownership. The ASF licenses this file
5 // to you under the Apache License, Version 2.0 (the
6 // "License"); you may not use this file except in compliance
7 // with the License. You may obtain a copy of the License at
8 //
9 // http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing,
12 // software distributed under the License is distributed on an
13 // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14 // KIND, either express or implied. See the License for the
15 // specific language governing permissions and limitations
16 // under the License.
17
18 #pragma once
19
#include <cstdint>
#include <limits>

#include "arrow/compute/exec.h"
#include "arrow/compute/kernels/codegen_internal.h"
#include "arrow/visitor_inline.h"
25
26 namespace arrow {
27
28 using internal::checked_cast;
29
30 namespace compute {
31 namespace internal {
32
33 struct KeyEncoder {
34 // the first byte of an encoded key is used to indicate nullity
35 static constexpr bool kExtraByteForNull = true;
36
37 static constexpr uint8_t kNullByte = 1;
38 static constexpr uint8_t kValidByte = 0;
39
40 virtual ~KeyEncoder() = default;
41
42 virtual void AddLength(const Datum&, int64_t batch_length, int32_t* lengths) = 0;
43
44 virtual void AddLengthNull(int32_t* length) = 0;
45
46 virtual Status Encode(const Datum&, int64_t batch_length, uint8_t** encoded_bytes) = 0;
47
48 virtual void EncodeNull(uint8_t** encoded_bytes) = 0;
49
50 virtual Result<std::shared_ptr<ArrayData>> Decode(uint8_t** encoded_bytes,
51 int32_t length, MemoryPool*) = 0;
52
53 // extract the null bitmap from the leading nullity bytes of encoded keys
54 static Status DecodeNulls(MemoryPool* pool, int32_t length, uint8_t** encoded_bytes,
55 std::shared_ptr<Buffer>* null_bitmap, int32_t* null_count);
56
57 static bool IsNull(const uint8_t* encoded_bytes) {
58 return encoded_bytes[0] == kNullByte;
59 }
60 };
61
62 struct BooleanKeyEncoder : KeyEncoder {
63 static constexpr int kByteWidth = 1;
64
65 void AddLength(const Datum& data, int64_t batch_length, int32_t* lengths) override;
66
67 void AddLengthNull(int32_t* length) override;
68
69 Status Encode(const Datum& data, int64_t batch_length,
70 uint8_t** encoded_bytes) override;
71
72 void EncodeNull(uint8_t** encoded_bytes) override;
73
74 Result<std::shared_ptr<ArrayData>> Decode(uint8_t** encoded_bytes, int32_t length,
75 MemoryPool* pool) override;
76 };
77
78 struct FixedWidthKeyEncoder : KeyEncoder {
79 explicit FixedWidthKeyEncoder(std::shared_ptr<DataType> type)
80 : type_(std::move(type)),
81 byte_width_(checked_cast<const FixedWidthType&>(*type_).bit_width() / 8) {}
82
83 void AddLength(const Datum& data, int64_t batch_length, int32_t* lengths) override;
84
85 void AddLengthNull(int32_t* length) override;
86
87 Status Encode(const Datum& data, int64_t batch_length,
88 uint8_t** encoded_bytes) override;
89
90 void EncodeNull(uint8_t** encoded_bytes) override;
91
92 Result<std::shared_ptr<ArrayData>> Decode(uint8_t** encoded_bytes, int32_t length,
93 MemoryPool* pool) override;
94
95 std::shared_ptr<DataType> type_;
96 int byte_width_;
97 };
98
99 struct DictionaryKeyEncoder : FixedWidthKeyEncoder {
100 DictionaryKeyEncoder(std::shared_ptr<DataType> type, MemoryPool* pool)
101 : FixedWidthKeyEncoder(std::move(type)), pool_(pool) {}
102
103 Status Encode(const Datum& data, int64_t batch_length,
104 uint8_t** encoded_bytes) override;
105
106 Result<std::shared_ptr<ArrayData>> Decode(uint8_t** encoded_bytes, int32_t length,
107 MemoryPool* pool) override;
108
109 MemoryPool* pool_;
110 std::shared_ptr<Array> dictionary_;
111 };
112
113 template <typename T>
114 struct VarLengthKeyEncoder : KeyEncoder {
115 using Offset = typename T::offset_type;
116
117 void AddLength(const Datum& data, int64_t batch_length, int32_t* lengths) override {
118 if (data.is_array()) {
119 int64_t i = 0;
120 VisitArrayDataInline<T>(
121 *data.array(),
122 [&](util::string_view bytes) {
123 lengths[i++] +=
124 kExtraByteForNull + sizeof(Offset) + static_cast<int32_t>(bytes.size());
125 },
126 [&] { lengths[i++] += kExtraByteForNull + sizeof(Offset); });
127 } else {
128 const Scalar& scalar = *data.scalar();
129 const int32_t buffer_size =
130 scalar.is_valid ? static_cast<int32_t>(UnboxScalar<T>::Unbox(scalar).size())
131 : 0;
132 for (int64_t i = 0; i < batch_length; i++) {
133 lengths[i] += kExtraByteForNull + sizeof(Offset) + buffer_size;
134 }
135 }
136 }
137
138 void AddLengthNull(int32_t* length) override {
139 *length += kExtraByteForNull + sizeof(Offset);
140 }
141
142 Status Encode(const Datum& data, int64_t batch_length,
143 uint8_t** encoded_bytes) override {
144 if (data.is_array()) {
145 VisitArrayDataInline<T>(
146 *data.array(),
147 [&](util::string_view bytes) {
148 auto& encoded_ptr = *encoded_bytes++;
149 *encoded_ptr++ = kValidByte;
150 util::SafeStore(encoded_ptr, static_cast<Offset>(bytes.size()));
151 encoded_ptr += sizeof(Offset);
152 memcpy(encoded_ptr, bytes.data(), bytes.size());
153 encoded_ptr += bytes.size();
154 },
155 [&] {
156 auto& encoded_ptr = *encoded_bytes++;
157 *encoded_ptr++ = kNullByte;
158 util::SafeStore(encoded_ptr, static_cast<Offset>(0));
159 encoded_ptr += sizeof(Offset);
160 });
161 } else {
162 const auto& scalar = data.scalar_as<BaseBinaryScalar>();
163 if (scalar.is_valid) {
164 const auto& bytes = *scalar.value;
165 for (int64_t i = 0; i < batch_length; i++) {
166 auto& encoded_ptr = *encoded_bytes++;
167 *encoded_ptr++ = kValidByte;
168 util::SafeStore(encoded_ptr, static_cast<Offset>(bytes.size()));
169 encoded_ptr += sizeof(Offset);
170 memcpy(encoded_ptr, bytes.data(), bytes.size());
171 encoded_ptr += bytes.size();
172 }
173 } else {
174 for (int64_t i = 0; i < batch_length; i++) {
175 auto& encoded_ptr = *encoded_bytes++;
176 *encoded_ptr++ = kNullByte;
177 util::SafeStore(encoded_ptr, static_cast<Offset>(0));
178 encoded_ptr += sizeof(Offset);
179 }
180 }
181 }
182 return Status::OK();
183 }
184
185 void EncodeNull(uint8_t** encoded_bytes) override {
186 auto& encoded_ptr = *encoded_bytes;
187 *encoded_ptr++ = kNullByte;
188 util::SafeStore(encoded_ptr, static_cast<Offset>(0));
189 encoded_ptr += sizeof(Offset);
190 }
191
192 Result<std::shared_ptr<ArrayData>> Decode(uint8_t** encoded_bytes, int32_t length,
193 MemoryPool* pool) override {
194 std::shared_ptr<Buffer> null_buf;
195 int32_t null_count;
196 ARROW_RETURN_NOT_OK(DecodeNulls(pool, length, encoded_bytes, &null_buf, &null_count));
197
198 Offset length_sum = 0;
199 for (int32_t i = 0; i < length; ++i) {
200 length_sum += util::SafeLoadAs<Offset>(encoded_bytes[i]);
201 }
202
203 ARROW_ASSIGN_OR_RAISE(auto offset_buf,
204 AllocateBuffer(sizeof(Offset) * (1 + length), pool));
205 ARROW_ASSIGN_OR_RAISE(auto key_buf, AllocateBuffer(length_sum));
206
207 auto raw_offsets = reinterpret_cast<Offset*>(offset_buf->mutable_data());
208 auto raw_keys = key_buf->mutable_data();
209
210 Offset current_offset = 0;
211 for (int32_t i = 0; i < length; ++i) {
212 raw_offsets[i] = current_offset;
213
214 auto key_length = util::SafeLoadAs<Offset>(encoded_bytes[i]);
215 encoded_bytes[i] += sizeof(Offset);
216
217 memcpy(raw_keys + current_offset, encoded_bytes[i], key_length);
218 encoded_bytes[i] += key_length;
219
220 current_offset += key_length;
221 }
222 raw_offsets[length] = current_offset;
223
224 return ArrayData::Make(
225 type_, length, {std::move(null_buf), std::move(offset_buf), std::move(key_buf)},
226 null_count);
227 }
228
229 explicit VarLengthKeyEncoder(std::shared_ptr<DataType> type) : type_(std::move(type)) {}
230
231 std::shared_ptr<DataType> type_;
232 };
233
234 class ARROW_EXPORT RowEncoder {
235 public:
236 static constexpr int kRowIdForNulls() { return -1; }
237
238 void Init(const std::vector<ValueDescr>& column_types, ExecContext* ctx);
239 void Clear();
240 Status EncodeAndAppend(const ExecBatch& batch);
241 Result<ExecBatch> Decode(int64_t num_rows, const int32_t* row_ids);
242
243 inline std::string encoded_row(int32_t i) const {
244 if (i == kRowIdForNulls()) {
245 return std::string(reinterpret_cast<const char*>(encoded_nulls_.data()),
246 encoded_nulls_.size());
247 }
248 int32_t row_length = offsets_[i + 1] - offsets_[i];
249 return std::string(reinterpret_cast<const char*>(bytes_.data() + offsets_[i]),
250 row_length);
251 }
252
253 int32_t num_rows() const {
254 return offsets_.size() == 0 ? 0 : static_cast<int32_t>(offsets_.size() - 1);
255 }
256
257 private:
258 ExecContext* ctx_;
259 std::vector<std::shared_ptr<KeyEncoder>> encoders_;
260 std::vector<int32_t> offsets_;
261 std::vector<uint8_t> bytes_;
262 std::vector<uint8_t> encoded_nulls_;
263 };
264
265 } // namespace internal
266 } // namespace compute
267 } // namespace arrow