]>
Commit | Line | Data |
---|---|---|
1d09f67e TL |
1 | // Licensed to the Apache Software Foundation (ASF) under one |
2 | // or more contributor license agreements. See the NOTICE file | |
3 | // distributed with this work for additional information | |
4 | // regarding copyright ownership. The ASF licenses this file | |
5 | // to you under the Apache License, Version 2.0 (the | |
6 | // "License"); you may not use this file except in compliance | |
7 | // with the License. You may obtain a copy of the License at | |
8 | // | |
9 | // http://www.apache.org/licenses/LICENSE-2.0 | |
10 | // | |
11 | // Unless required by applicable law or agreed to in writing, | |
12 | // software distributed under the License is distributed on an | |
13 | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | |
14 | // KIND, either express or implied. See the License for the | |
15 | // specific language governing permissions and limitations | |
16 | // under the License. | |
17 | ||
18 | #pragma once | |
19 | ||
20 | #include <cstdint> | |
21 | ||
22 | #include "arrow/compute/exec.h" | |
23 | #include "arrow/compute/kernels/codegen_internal.h" | |
24 | #include "arrow/visitor_inline.h" | |
25 | ||
26 | namespace arrow { | |
27 | ||
28 | using internal::checked_cast; | |
29 | ||
30 | namespace compute { | |
31 | namespace internal { | |
32 | ||
33 | struct KeyEncoder { | |
34 | // the first byte of an encoded key is used to indicate nullity | |
35 | static constexpr bool kExtraByteForNull = true; | |
36 | ||
37 | static constexpr uint8_t kNullByte = 1; | |
38 | static constexpr uint8_t kValidByte = 0; | |
39 | ||
40 | virtual ~KeyEncoder() = default; | |
41 | ||
42 | virtual void AddLength(const Datum&, int64_t batch_length, int32_t* lengths) = 0; | |
43 | ||
44 | virtual void AddLengthNull(int32_t* length) = 0; | |
45 | ||
46 | virtual Status Encode(const Datum&, int64_t batch_length, uint8_t** encoded_bytes) = 0; | |
47 | ||
48 | virtual void EncodeNull(uint8_t** encoded_bytes) = 0; | |
49 | ||
50 | virtual Result<std::shared_ptr<ArrayData>> Decode(uint8_t** encoded_bytes, | |
51 | int32_t length, MemoryPool*) = 0; | |
52 | ||
53 | // extract the null bitmap from the leading nullity bytes of encoded keys | |
54 | static Status DecodeNulls(MemoryPool* pool, int32_t length, uint8_t** encoded_bytes, | |
55 | std::shared_ptr<Buffer>* null_bitmap, int32_t* null_count); | |
56 | ||
57 | static bool IsNull(const uint8_t* encoded_bytes) { | |
58 | return encoded_bytes[0] == kNullByte; | |
59 | } | |
60 | }; | |
61 | ||
62 | struct BooleanKeyEncoder : KeyEncoder { | |
63 | static constexpr int kByteWidth = 1; | |
64 | ||
65 | void AddLength(const Datum& data, int64_t batch_length, int32_t* lengths) override; | |
66 | ||
67 | void AddLengthNull(int32_t* length) override; | |
68 | ||
69 | Status Encode(const Datum& data, int64_t batch_length, | |
70 | uint8_t** encoded_bytes) override; | |
71 | ||
72 | void EncodeNull(uint8_t** encoded_bytes) override; | |
73 | ||
74 | Result<std::shared_ptr<ArrayData>> Decode(uint8_t** encoded_bytes, int32_t length, | |
75 | MemoryPool* pool) override; | |
76 | }; | |
77 | ||
78 | struct FixedWidthKeyEncoder : KeyEncoder { | |
79 | explicit FixedWidthKeyEncoder(std::shared_ptr<DataType> type) | |
80 | : type_(std::move(type)), | |
81 | byte_width_(checked_cast<const FixedWidthType&>(*type_).bit_width() / 8) {} | |
82 | ||
83 | void AddLength(const Datum& data, int64_t batch_length, int32_t* lengths) override; | |
84 | ||
85 | void AddLengthNull(int32_t* length) override; | |
86 | ||
87 | Status Encode(const Datum& data, int64_t batch_length, | |
88 | uint8_t** encoded_bytes) override; | |
89 | ||
90 | void EncodeNull(uint8_t** encoded_bytes) override; | |
91 | ||
92 | Result<std::shared_ptr<ArrayData>> Decode(uint8_t** encoded_bytes, int32_t length, | |
93 | MemoryPool* pool) override; | |
94 | ||
95 | std::shared_ptr<DataType> type_; | |
96 | int byte_width_; | |
97 | }; | |
98 | ||
99 | struct DictionaryKeyEncoder : FixedWidthKeyEncoder { | |
100 | DictionaryKeyEncoder(std::shared_ptr<DataType> type, MemoryPool* pool) | |
101 | : FixedWidthKeyEncoder(std::move(type)), pool_(pool) {} | |
102 | ||
103 | Status Encode(const Datum& data, int64_t batch_length, | |
104 | uint8_t** encoded_bytes) override; | |
105 | ||
106 | Result<std::shared_ptr<ArrayData>> Decode(uint8_t** encoded_bytes, int32_t length, | |
107 | MemoryPool* pool) override; | |
108 | ||
109 | MemoryPool* pool_; | |
110 | std::shared_ptr<Array> dictionary_; | |
111 | }; | |
112 | ||
113 | template <typename T> | |
114 | struct VarLengthKeyEncoder : KeyEncoder { | |
115 | using Offset = typename T::offset_type; | |
116 | ||
117 | void AddLength(const Datum& data, int64_t batch_length, int32_t* lengths) override { | |
118 | if (data.is_array()) { | |
119 | int64_t i = 0; | |
120 | VisitArrayDataInline<T>( | |
121 | *data.array(), | |
122 | [&](util::string_view bytes) { | |
123 | lengths[i++] += | |
124 | kExtraByteForNull + sizeof(Offset) + static_cast<int32_t>(bytes.size()); | |
125 | }, | |
126 | [&] { lengths[i++] += kExtraByteForNull + sizeof(Offset); }); | |
127 | } else { | |
128 | const Scalar& scalar = *data.scalar(); | |
129 | const int32_t buffer_size = | |
130 | scalar.is_valid ? static_cast<int32_t>(UnboxScalar<T>::Unbox(scalar).size()) | |
131 | : 0; | |
132 | for (int64_t i = 0; i < batch_length; i++) { | |
133 | lengths[i] += kExtraByteForNull + sizeof(Offset) + buffer_size; | |
134 | } | |
135 | } | |
136 | } | |
137 | ||
138 | void AddLengthNull(int32_t* length) override { | |
139 | *length += kExtraByteForNull + sizeof(Offset); | |
140 | } | |
141 | ||
142 | Status Encode(const Datum& data, int64_t batch_length, | |
143 | uint8_t** encoded_bytes) override { | |
144 | if (data.is_array()) { | |
145 | VisitArrayDataInline<T>( | |
146 | *data.array(), | |
147 | [&](util::string_view bytes) { | |
148 | auto& encoded_ptr = *encoded_bytes++; | |
149 | *encoded_ptr++ = kValidByte; | |
150 | util::SafeStore(encoded_ptr, static_cast<Offset>(bytes.size())); | |
151 | encoded_ptr += sizeof(Offset); | |
152 | memcpy(encoded_ptr, bytes.data(), bytes.size()); | |
153 | encoded_ptr += bytes.size(); | |
154 | }, | |
155 | [&] { | |
156 | auto& encoded_ptr = *encoded_bytes++; | |
157 | *encoded_ptr++ = kNullByte; | |
158 | util::SafeStore(encoded_ptr, static_cast<Offset>(0)); | |
159 | encoded_ptr += sizeof(Offset); | |
160 | }); | |
161 | } else { | |
162 | const auto& scalar = data.scalar_as<BaseBinaryScalar>(); | |
163 | if (scalar.is_valid) { | |
164 | const auto& bytes = *scalar.value; | |
165 | for (int64_t i = 0; i < batch_length; i++) { | |
166 | auto& encoded_ptr = *encoded_bytes++; | |
167 | *encoded_ptr++ = kValidByte; | |
168 | util::SafeStore(encoded_ptr, static_cast<Offset>(bytes.size())); | |
169 | encoded_ptr += sizeof(Offset); | |
170 | memcpy(encoded_ptr, bytes.data(), bytes.size()); | |
171 | encoded_ptr += bytes.size(); | |
172 | } | |
173 | } else { | |
174 | for (int64_t i = 0; i < batch_length; i++) { | |
175 | auto& encoded_ptr = *encoded_bytes++; | |
176 | *encoded_ptr++ = kNullByte; | |
177 | util::SafeStore(encoded_ptr, static_cast<Offset>(0)); | |
178 | encoded_ptr += sizeof(Offset); | |
179 | } | |
180 | } | |
181 | } | |
182 | return Status::OK(); | |
183 | } | |
184 | ||
185 | void EncodeNull(uint8_t** encoded_bytes) override { | |
186 | auto& encoded_ptr = *encoded_bytes; | |
187 | *encoded_ptr++ = kNullByte; | |
188 | util::SafeStore(encoded_ptr, static_cast<Offset>(0)); | |
189 | encoded_ptr += sizeof(Offset); | |
190 | } | |
191 | ||
192 | Result<std::shared_ptr<ArrayData>> Decode(uint8_t** encoded_bytes, int32_t length, | |
193 | MemoryPool* pool) override { | |
194 | std::shared_ptr<Buffer> null_buf; | |
195 | int32_t null_count; | |
196 | ARROW_RETURN_NOT_OK(DecodeNulls(pool, length, encoded_bytes, &null_buf, &null_count)); | |
197 | ||
198 | Offset length_sum = 0; | |
199 | for (int32_t i = 0; i < length; ++i) { | |
200 | length_sum += util::SafeLoadAs<Offset>(encoded_bytes[i]); | |
201 | } | |
202 | ||
203 | ARROW_ASSIGN_OR_RAISE(auto offset_buf, | |
204 | AllocateBuffer(sizeof(Offset) * (1 + length), pool)); | |
205 | ARROW_ASSIGN_OR_RAISE(auto key_buf, AllocateBuffer(length_sum)); | |
206 | ||
207 | auto raw_offsets = reinterpret_cast<Offset*>(offset_buf->mutable_data()); | |
208 | auto raw_keys = key_buf->mutable_data(); | |
209 | ||
210 | Offset current_offset = 0; | |
211 | for (int32_t i = 0; i < length; ++i) { | |
212 | raw_offsets[i] = current_offset; | |
213 | ||
214 | auto key_length = util::SafeLoadAs<Offset>(encoded_bytes[i]); | |
215 | encoded_bytes[i] += sizeof(Offset); | |
216 | ||
217 | memcpy(raw_keys + current_offset, encoded_bytes[i], key_length); | |
218 | encoded_bytes[i] += key_length; | |
219 | ||
220 | current_offset += key_length; | |
221 | } | |
222 | raw_offsets[length] = current_offset; | |
223 | ||
224 | return ArrayData::Make( | |
225 | type_, length, {std::move(null_buf), std::move(offset_buf), std::move(key_buf)}, | |
226 | null_count); | |
227 | } | |
228 | ||
229 | explicit VarLengthKeyEncoder(std::shared_ptr<DataType> type) : type_(std::move(type)) {} | |
230 | ||
231 | std::shared_ptr<DataType> type_; | |
232 | }; | |
233 | ||
234 | class ARROW_EXPORT RowEncoder { | |
235 | public: | |
236 | static constexpr int kRowIdForNulls() { return -1; } | |
237 | ||
238 | void Init(const std::vector<ValueDescr>& column_types, ExecContext* ctx); | |
239 | void Clear(); | |
240 | Status EncodeAndAppend(const ExecBatch& batch); | |
241 | Result<ExecBatch> Decode(int64_t num_rows, const int32_t* row_ids); | |
242 | ||
243 | inline std::string encoded_row(int32_t i) const { | |
244 | if (i == kRowIdForNulls()) { | |
245 | return std::string(reinterpret_cast<const char*>(encoded_nulls_.data()), | |
246 | encoded_nulls_.size()); | |
247 | } | |
248 | int32_t row_length = offsets_[i + 1] - offsets_[i]; | |
249 | return std::string(reinterpret_cast<const char*>(bytes_.data() + offsets_[i]), | |
250 | row_length); | |
251 | } | |
252 | ||
253 | int32_t num_rows() const { | |
254 | return offsets_.size() == 0 ? 0 : static_cast<int32_t>(offsets_.size() - 1); | |
255 | } | |
256 | ||
257 | private: | |
258 | ExecContext* ctx_; | |
259 | std::vector<std::shared_ptr<KeyEncoder>> encoders_; | |
260 | std::vector<int32_t> offsets_; | |
261 | std::vector<uint8_t> bytes_; | |
262 | std::vector<uint8_t> encoded_nulls_; | |
263 | }; | |
264 | ||
265 | } // namespace internal | |
266 | } // namespace compute | |
267 | } // namespace arrow |