1 // Licensed to the Apache Software Foundation (ASF) under one
2 // or more contributor license agreements. See the NOTICE file
3 // distributed with this work for additional information
4 // regarding copyright ownership. The ASF licenses this file
5 // to you under the Apache License, Version 2.0 (the
6 // "License"); you may not use this file except in compliance
7 // with the License. You may obtain a copy of the License at
8 //
9 // http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing,
12 // software distributed under the License is distributed on an
13 // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14 // KIND, either express or implied. See the License for the
15 // specific language governing permissions and limitations
16 // under the License.
17
18 #include "arrow/ipc/writer.h"
19
20 #include <algorithm>
21 #include <cstdint>
22 #include <cstring>
23 #include <limits>
24 #include <sstream>
25 #include <string>
26 #include <type_traits>
27 #include <unordered_map>
28 #include <utility>
29 #include <vector>
30
31 #include "arrow/array.h"
32 #include "arrow/buffer.h"
33 #include "arrow/device.h"
34 #include "arrow/extension_type.h"
35 #include "arrow/io/interfaces.h"
36 #include "arrow/io/memory.h"
37 #include "arrow/ipc/dictionary.h"
38 #include "arrow/ipc/message.h"
39 #include "arrow/ipc/metadata_internal.h"
40 #include "arrow/ipc/util.h"
41 #include "arrow/record_batch.h"
42 #include "arrow/result_internal.h"
43 #include "arrow/sparse_tensor.h"
44 #include "arrow/status.h"
45 #include "arrow/table.h"
46 #include "arrow/type.h"
47 #include "arrow/type_traits.h"
48 #include "arrow/util/bit_util.h"
49 #include "arrow/util/bitmap_ops.h"
50 #include "arrow/util/checked_cast.h"
51 #include "arrow/util/compression.h"
52 #include "arrow/util/endian.h"
53 #include "arrow/util/key_value_metadata.h"
54 #include "arrow/util/logging.h"
55 #include "arrow/util/make_unique.h"
56 #include "arrow/util/parallel.h"
57 #include "arrow/visitor_inline.h"
58
59 namespace arrow {
60
61 using internal::checked_cast;
62 using internal::checked_pointer_cast;
63 using internal::CopyBitmap;
64 using internal::GetByteWidth;
65
66 namespace ipc {
67
68 using internal::FileBlock;
69 using internal::kArrowMagicBytes;
70
71 namespace {
72
73 bool HasNestedDict(const ArrayData& data) {
74 if (data.type->id() == Type::DICTIONARY) {
75 return true;
76 }
77 for (const auto& child : data.child_data) {
78 if (HasNestedDict(*child)) {
79 return true;
80 }
81 }
82 return false;
83 }
84
85 Status GetTruncatedBitmap(int64_t offset, int64_t length,
86 const std::shared_ptr<Buffer> input, MemoryPool* pool,
87 std::shared_ptr<Buffer>* buffer) {
88 if (!input) {
89 *buffer = input;
90 return Status::OK();
91 }
92 int64_t min_length = PaddedLength(BitUtil::BytesForBits(length));
93 if (offset != 0 || min_length < input->size()) {
94 // With a sliced array / non-zero offset, we must copy the bitmap
95 ARROW_ASSIGN_OR_RAISE(*buffer, CopyBitmap(pool, input->data(), offset, length));
96 } else {
97 *buffer = input;
98 }
99 return Status::OK();
100 }
101
102 Status GetTruncatedBuffer(int64_t offset, int64_t length, int32_t byte_width,
103 const std::shared_ptr<Buffer> input, MemoryPool* pool,
104 std::shared_ptr<Buffer>* buffer) {
105 if (!input) {
106 *buffer = input;
107 return Status::OK();
108 }
109 int64_t padded_length = PaddedLength(length * byte_width);
110 if (offset != 0 || padded_length < input->size()) {
111 *buffer =
112 SliceBuffer(input, offset * byte_width, std::min(padded_length, input->size()));
113 } else {
114 *buffer = input;
115 }
116 return Status::OK();
117 }
118
119 static inline bool NeedTruncate(int64_t offset, const Buffer* buffer,
120 int64_t min_length) {
121 // buffer can be NULL
122 if (buffer == nullptr) {
123 return false;
124 }
125 return offset != 0 || min_length < buffer->size();
126 }
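
// For example (a sketch of when these truncation helpers kick in): an
// Int32Array of length 100 that was sliced with Slice(10, 20) reports
// offset() == 10 and min_length == PaddedLength(20 * 4) == 80, so
// NeedTruncate() returns true and the 400-byte values buffer is sliced
// down to the ~80 bytes actually needed before being written.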
127
128 class RecordBatchSerializer {
129 public:
130 RecordBatchSerializer(int64_t buffer_start_offset, const IpcWriteOptions& options,
131 IpcPayload* out)
132 : out_(out),
133 options_(options),
134 max_recursion_depth_(options.max_recursion_depth),
135 buffer_start_offset_(buffer_start_offset) {
136 DCHECK_GT(max_recursion_depth_, 0);
137 }
138
139 virtual ~RecordBatchSerializer() = default;
140
141 Status VisitArray(const Array& arr) {
142 static std::shared_ptr<Buffer> kNullBuffer = std::make_shared<Buffer>(nullptr, 0);
143
144 if (max_recursion_depth_ <= 0) {
145 return Status::Invalid("Max recursion depth reached");
146 }
147
148 if (!options_.allow_64bit && arr.length() > std::numeric_limits<int32_t>::max()) {
149 return Status::CapacityError("Cannot write arrays larger than 2^31 - 1 in length");
150 }
151
152 // push back all common elements
153 field_nodes_.push_back({arr.length(), arr.null_count(), 0});
154
155 // In V4, null types have no validity bitmap
156 // In V5 and later, null and union types have no validity bitmap
157 if (internal::HasValidityBitmap(arr.type_id(), options_.metadata_version)) {
158 if (arr.null_count() > 0) {
159 std::shared_ptr<Buffer> bitmap;
160 RETURN_NOT_OK(GetTruncatedBitmap(arr.offset(), arr.length(), arr.null_bitmap(),
161 options_.memory_pool, &bitmap));
162 out_->body_buffers.emplace_back(bitmap);
163 } else {
164 // Push a dummy zero-length buffer, not to be copied
165 out_->body_buffers.emplace_back(kNullBuffer);
166 }
167 }
168 return VisitType(arr);
169 }
170
171 // Override this for writing dictionary metadata
172 virtual Status SerializeMetadata(int64_t num_rows) {
173 return WriteRecordBatchMessage(num_rows, out_->body_length, custom_metadata_,
174 field_nodes_, buffer_meta_, options_, &out_->metadata);
175 }
176
177 void AppendCustomMetadata(const std::string& key, const std::string& value) {
178 if (!custom_metadata_) {
179 custom_metadata_ = std::make_shared<KeyValueMetadata>();
180 }
181 custom_metadata_->Append(key, value);
182 }
183
184 Status CompressBuffer(const Buffer& buffer, util::Codec* codec,
185 std::shared_ptr<Buffer>* out) {
186 // Convert buffer to uncompressed-length-prefixed compressed buffer
187 int64_t maximum_length = codec->MaxCompressedLen(buffer.size(), buffer.data());
188 ARROW_ASSIGN_OR_RAISE(auto result, AllocateBuffer(maximum_length + sizeof(int64_t)));
189
190 int64_t actual_length;
191 ARROW_ASSIGN_OR_RAISE(actual_length,
192 codec->Compress(buffer.size(), buffer.data(), maximum_length,
193 result->mutable_data() + sizeof(int64_t)));
194 *reinterpret_cast<int64_t*>(result->mutable_data()) =
195 BitUtil::ToLittleEndian(buffer.size());
196 *out = SliceBuffer(std::move(result), /*offset=*/0, actual_length + sizeof(int64_t));
197 return Status::OK();
198 }
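
// A minimal sketch of the layout produced by CompressBuffer() above: the
// first 8 bytes hold the uncompressed length as a little-endian int64,
// followed by the codec output. A reader would recover the original bytes
// roughly like this (illustrative only; error handling omitted, `buf`,
// `codec` and `pool` are assumed to exist):
//
//   int64_t uncompressed_size = BitUtil::FromLittleEndian(
//       *reinterpret_cast<const int64_t*>(buf->data()));
//   ARROW_ASSIGN_OR_RAISE(auto out, AllocateBuffer(uncompressed_size, pool));
//   ARROW_ASSIGN_OR_RAISE(
//       auto decompressed_size,
//       codec->Decompress(buf->size() - sizeof(int64_t),
//                         buf->data() + sizeof(int64_t),
//                         uncompressed_size, out->mutable_data()));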
199
200 Status CompressBodyBuffers() {
201 RETURN_NOT_OK(
202 internal::CheckCompressionSupported(options_.codec->compression_type()));
203
204 auto CompressOne = [&](size_t i) {
205 if (out_->body_buffers[i]->size() > 0) {
206 RETURN_NOT_OK(CompressBuffer(*out_->body_buffers[i], options_.codec.get(),
207 &out_->body_buffers[i]));
208 }
209 return Status::OK();
210 };
211
212 return ::arrow::internal::OptionalParallelFor(
213 options_.use_threads, static_cast<int>(out_->body_buffers.size()), CompressOne);
214 }
215
216 Status Assemble(const RecordBatch& batch) {
217 if (field_nodes_.size() > 0) {
218 field_nodes_.clear();
219 buffer_meta_.clear();
220 out_->body_buffers.clear();
221 }
222
223 // Perform depth-first traversal of the row-batch
224 for (int i = 0; i < batch.num_columns(); ++i) {
225 RETURN_NOT_OK(VisitArray(*batch.column(i)));
226 }
227
228 if (options_.codec != nullptr) {
229 RETURN_NOT_OK(CompressBodyBuffers());
230 }
231
232 // The position for the start of a buffer relative to the passed frame of
233 // reference. May be 0 or some other position in an address space
234 int64_t offset = buffer_start_offset_;
235
236 buffer_meta_.reserve(out_->body_buffers.size());
237
238 // Construct the buffer metadata for the record batch header
239 for (const auto& buffer : out_->body_buffers) {
240 int64_t size = 0;
241 int64_t padding = 0;
242
243 // The buffer might be null if we are handling zero row lengths.
244 if (buffer) {
245 size = buffer->size();
246 padding = BitUtil::RoundUpToMultipleOf8(size) - size;
247 }
248
249 buffer_meta_.push_back({offset, size});
250 offset += size + padding;
251 }
252
253 out_->body_length = offset - buffer_start_offset_;
254 DCHECK(BitUtil::IsMultipleOf8(out_->body_length));
255
256 // Now that we have computed the locations of all of the buffers in shared
257 // memory, the data header can be converted to a flatbuffer and written out
258 //
259 // Note: The memory written here is prefixed by the size of the flatbuffer
260 // itself as an int32_t.
261 return SerializeMetadata(batch.num_rows());
262 }
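
// A small worked example of the offset bookkeeping in Assemble() above
// (with buffer_start_offset_ == 0): body buffers of 5 and 12 bytes yield
// buffer_meta_ = {{0, 5}, {8, 12}} -- 5 pads up to 8 and 12 pads up to 16 --
// so body_length becomes 24 and every buffer stays 8-byte aligned.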
263
264 template <typename ArrayType>
265 Status GetZeroBasedValueOffsets(const ArrayType& array,
266 std::shared_ptr<Buffer>* value_offsets) {
267 // Share slicing logic between ListArray, BinaryArray and LargeBinaryArray
268 using offset_type = typename ArrayType::offset_type;
269
270 auto offsets = array.value_offsets();
271
272 int64_t required_bytes = sizeof(offset_type) * (array.length() + 1);
273 if (array.offset() != 0) {
274 // If we have a non-zero offset, then the value offsets do not start at
275 // zero. We must a) create a new offsets array with shifted offsets and
276 // b) slice the values array accordingly
277
278 ARROW_ASSIGN_OR_RAISE(auto shifted_offsets,
279 AllocateBuffer(required_bytes, options_.memory_pool));
280
281 offset_type* dest_offsets =
282 reinterpret_cast<offset_type*>(shifted_offsets->mutable_data());
283 const offset_type start_offset = array.value_offset(0);
284
285 for (int i = 0; i < array.length(); ++i) {
286 dest_offsets[i] = array.value_offset(i) - start_offset;
287 }
288 // Final offset
289 dest_offsets[array.length()] = array.value_offset(array.length()) - start_offset;
290 offsets = std::move(shifted_offsets);
291 } else {
292 // ARROW-6046: Slice offsets to used extent, in case we have a truncated
293 // slice
294 if (offsets != nullptr && offsets->size() > required_bytes) {
295 offsets = SliceBuffer(offsets, 0, required_bytes);
296 }
297 }
298 *value_offsets = std::move(offsets);
299 return Status::OK();
300 }
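
// For example (sketch): a StringArray with value offsets [0, 3, 5, 9] that
// was sliced with Slice(1, 2) reports value_offset(0) == 3, so the loop above
// writes the rebased offsets [0, 2, 6]; the binary visitor below then slices
// the value data to the byte range starting at offset 3 (plus padding).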
301
302 Status Visit(const BooleanArray& array) {
303 std::shared_ptr<Buffer> data;
304 RETURN_NOT_OK(GetTruncatedBitmap(array.offset(), array.length(), array.values(),
305 options_.memory_pool, &data));
306 out_->body_buffers.emplace_back(data);
307 return Status::OK();
308 }
309
310 Status Visit(const NullArray& array) { return Status::OK(); }
311
312 template <typename T>
313 typename std::enable_if<is_number_type<typename T::TypeClass>::value ||
314 is_temporal_type<typename T::TypeClass>::value ||
315 is_fixed_size_binary_type<typename T::TypeClass>::value,
316 Status>::type
317 Visit(const T& array) {
318 std::shared_ptr<Buffer> data = array.values();
319
320 const int64_t type_width = GetByteWidth(*array.type());
321 int64_t min_length = PaddedLength(array.length() * type_width);
322
323 if (NeedTruncate(array.offset(), data.get(), min_length)) {
324 // Non-zero offset, slice the buffer
325 const int64_t byte_offset = array.offset() * type_width;
326
327 // Send padding if it's available
328 const int64_t buffer_length =
329 std::min(BitUtil::RoundUpToMultipleOf8(array.length() * type_width),
330 data->size() - byte_offset);
331 data = SliceBuffer(data, byte_offset, buffer_length);
332 }
333 out_->body_buffers.emplace_back(data);
334 return Status::OK();
335 }
336
337 template <typename T>
338 enable_if_base_binary<typename T::TypeClass, Status> Visit(const T& array) {
339 std::shared_ptr<Buffer> value_offsets;
340 RETURN_NOT_OK(GetZeroBasedValueOffsets<T>(array, &value_offsets));
341 auto data = array.value_data();
342
343 int64_t total_data_bytes = 0;
344 if (value_offsets) {
345 total_data_bytes = array.value_offset(array.length()) - array.value_offset(0);
346 }
347 if (NeedTruncate(array.offset(), data.get(), total_data_bytes)) {
348 // Slice the data buffer to include only the range we need now
349 const int64_t start_offset = array.value_offset(0);
350 const int64_t slice_length =
351 std::min(PaddedLength(total_data_bytes), data->size() - start_offset);
352 data = SliceBuffer(data, start_offset, slice_length);
353 }
354
355 out_->body_buffers.emplace_back(value_offsets);
356 out_->body_buffers.emplace_back(data);
357 return Status::OK();
358 }
359
360 template <typename T>
361 enable_if_base_list<typename T::TypeClass, Status> Visit(const T& array) {
362 using offset_type = typename T::offset_type;
363
364 std::shared_ptr<Buffer> value_offsets;
365 RETURN_NOT_OK(GetZeroBasedValueOffsets<T>(array, &value_offsets));
366 out_->body_buffers.emplace_back(value_offsets);
367
368 --max_recursion_depth_;
369 std::shared_ptr<Array> values = array.values();
370
371 offset_type values_offset = 0;
372 offset_type values_length = 0;
373 if (value_offsets) {
374 values_offset = array.value_offset(0);
375 values_length = array.value_offset(array.length()) - values_offset;
376 }
377
378 if (array.offset() != 0 || values_length < values->length()) {
379 // Must also slice the values
380 values = values->Slice(values_offset, values_length);
381 }
382 RETURN_NOT_OK(VisitArray(*values));
383 ++max_recursion_depth_;
384 return Status::OK();
385 }
386
387 Status Visit(const FixedSizeListArray& array) {
388 --max_recursion_depth_;
389 auto size = array.list_type()->list_size();
390 auto values = array.values()->Slice(array.offset() * size, array.length() * size);
391
392 RETURN_NOT_OK(VisitArray(*values));
393 ++max_recursion_depth_;
394 return Status::OK();
395 }
396
397 Status Visit(const StructArray& array) {
398 --max_recursion_depth_;
399 for (int i = 0; i < array.num_fields(); ++i) {
400 std::shared_ptr<Array> field = array.field(i);
401 RETURN_NOT_OK(VisitArray(*field));
402 }
403 ++max_recursion_depth_;
404 return Status::OK();
405 }
406
407 Status Visit(const SparseUnionArray& array) {
408 const int64_t offset = array.offset();
409 const int64_t length = array.length();
410
411 std::shared_ptr<Buffer> type_codes;
412 RETURN_NOT_OK(GetTruncatedBuffer(
413 offset, length, static_cast<int32_t>(sizeof(UnionArray::type_code_t)),
414 array.type_codes(), options_.memory_pool, &type_codes));
415 out_->body_buffers.emplace_back(type_codes);
416
417 --max_recursion_depth_;
418 for (int i = 0; i < array.num_fields(); ++i) {
419 // Sparse union, slicing is done for us by field()
420 RETURN_NOT_OK(VisitArray(*array.field(i)));
421 }
422 ++max_recursion_depth_;
423 return Status::OK();
424 }
425
426 Status Visit(const DenseUnionArray& array) {
427 const int64_t offset = array.offset();
428 const int64_t length = array.length();
429
430 std::shared_ptr<Buffer> type_codes;
431 RETURN_NOT_OK(GetTruncatedBuffer(
432 offset, length, static_cast<int32_t>(sizeof(UnionArray::type_code_t)),
433 array.type_codes(), options_.memory_pool, &type_codes));
434 out_->body_buffers.emplace_back(type_codes);
435
436 --max_recursion_depth_;
437 const auto& type = checked_cast<const UnionType&>(*array.type());
438
439 std::shared_ptr<Buffer> value_offsets;
440 RETURN_NOT_OK(
441 GetTruncatedBuffer(offset, length, static_cast<int32_t>(sizeof(int32_t)),
442 array.value_offsets(), options_.memory_pool, &value_offsets));
443
444     // The Union type codes are not necessarily 0-indexed
445 int8_t max_code = 0;
446 for (int8_t code : type.type_codes()) {
447 if (code > max_code) {
448 max_code = code;
449 }
450 }
451
452 // Allocate an array of child offsets. Set all to -1 to indicate that we
453 // haven't observed a first occurrence of a particular child yet
454 std::vector<int32_t> child_offsets(max_code + 1, -1);
455 std::vector<int32_t> child_lengths(max_code + 1, 0);
456
457 if (offset != 0) {
458 // This is an unpleasant case. Because the offsets are different for
459 // each child array, when we have a sliced array, we need to "rebase"
460 // the value_offsets for each array
461
462 const int32_t* unshifted_offsets = array.raw_value_offsets();
463 const int8_t* type_codes = array.raw_type_codes();
464
465 // Allocate the shifted offsets
466 ARROW_ASSIGN_OR_RAISE(
467 auto shifted_offsets_buffer,
468 AllocateBuffer(length * sizeof(int32_t), options_.memory_pool));
469 int32_t* shifted_offsets =
470 reinterpret_cast<int32_t*>(shifted_offsets_buffer->mutable_data());
471
472 // Offsets may not be ascending, so we need to find out the start offset
473 // for each child
474 for (int64_t i = 0; i < length; ++i) {
475 const uint8_t code = type_codes[i];
476 if (child_offsets[code] == -1) {
477 child_offsets[code] = unshifted_offsets[i];
478 } else {
479 child_offsets[code] = std::min(child_offsets[code], unshifted_offsets[i]);
480 }
481 }
482
483 // Now compute shifted offsets by subtracting child offset
484 for (int64_t i = 0; i < length; ++i) {
485 const int8_t code = type_codes[i];
486 shifted_offsets[i] = unshifted_offsets[i] - child_offsets[code];
488         // Update the child length to account for the observed value
488 child_lengths[code] = std::max(child_lengths[code], shifted_offsets[i] + 1);
489 }
490
491 value_offsets = std::move(shifted_offsets_buffer);
492 }
493 out_->body_buffers.emplace_back(value_offsets);
494
495 // Visit children and slice accordingly
496 for (int i = 0; i < type.num_fields(); ++i) {
497 std::shared_ptr<Array> child = array.field(i);
498
499 // TODO: ARROW-809, for sliced unions, tricky to know how much to
500 // truncate the children. For now, we are truncating the children to be
501 // no longer than the parent union.
502 if (offset != 0) {
503 const int8_t code = type.type_codes()[i];
504 const int64_t child_offset = child_offsets[code];
505 const int64_t child_length = child_lengths[code];
506
507 if (child_offset > 0) {
508 child = child->Slice(child_offset, child_length);
509 } else if (child_length < child->length()) {
510         // This also covers the case where the child is not encountered at all
511 child = child->Slice(0, child_length);
512 }
513 }
514 RETURN_NOT_OK(VisitArray(*child));
515 }
516 ++max_recursion_depth_;
517 return Status::OK();
518 }
519
520 Status Visit(const DictionaryArray& array) {
521 // Dictionary written out separately. Slice offset contained in the indices
522 return VisitType(*array.indices());
523 }
524
525 Status Visit(const ExtensionArray& array) { return VisitType(*array.storage()); }
526
527 Status VisitType(const Array& values) { return VisitArrayInline(values, this); }
528
529 protected:
530 // Destination for output buffers
531 IpcPayload* out_;
532
533 std::shared_ptr<KeyValueMetadata> custom_metadata_;
534
535 std::vector<internal::FieldMetadata> field_nodes_;
536 std::vector<internal::BufferMetadata> buffer_meta_;
537
538 const IpcWriteOptions& options_;
539 int64_t max_recursion_depth_;
540 int64_t buffer_start_offset_;
541 };
542
543 class DictionarySerializer : public RecordBatchSerializer {
544 public:
545 DictionarySerializer(int64_t dictionary_id, bool is_delta, int64_t buffer_start_offset,
546 const IpcWriteOptions& options, IpcPayload* out)
547 : RecordBatchSerializer(buffer_start_offset, options, out),
548 dictionary_id_(dictionary_id),
549 is_delta_(is_delta) {}
550
551 Status SerializeMetadata(int64_t num_rows) override {
552 return WriteDictionaryMessage(dictionary_id_, is_delta_, num_rows, out_->body_length,
553 custom_metadata_, field_nodes_, buffer_meta_, options_,
554 &out_->metadata);
555 }
556
557 Status Assemble(const std::shared_ptr<Array>& dictionary) {
558 // Make a dummy record batch. A bit tedious as we have to make a schema
559 auto schema = arrow::schema({arrow::field("dictionary", dictionary->type())});
560 auto batch = RecordBatch::Make(std::move(schema), dictionary->length(), {dictionary});
561 return RecordBatchSerializer::Assemble(*batch);
562 }
563
564 private:
565 int64_t dictionary_id_;
566 bool is_delta_;
567 };
568
569 } // namespace
570
571 Status WriteIpcPayload(const IpcPayload& payload, const IpcWriteOptions& options,
572 io::OutputStream* dst, int32_t* metadata_length) {
573 RETURN_NOT_OK(WriteMessage(*payload.metadata, options, dst, metadata_length));
574
575 #ifndef NDEBUG
576 RETURN_NOT_OK(CheckAligned(dst));
577 #endif
578
579 // Now write the buffers
580 for (size_t i = 0; i < payload.body_buffers.size(); ++i) {
581 const std::shared_ptr<Buffer>& buffer = payload.body_buffers[i];
582 int64_t size = 0;
583 int64_t padding = 0;
584
585 // The buffer might be null if we are handling zero row lengths.
586 if (buffer) {
587 size = buffer->size();
588 padding = BitUtil::RoundUpToMultipleOf8(size) - size;
589 }
590
591 if (size > 0) {
592 RETURN_NOT_OK(dst->Write(buffer));
593 }
594
595 if (padding > 0) {
596 RETURN_NOT_OK(dst->Write(kPaddingBytes, padding));
597 }
598 }
599
600 #ifndef NDEBUG
601 RETURN_NOT_OK(CheckAligned(dst));
602 #endif
603
604 return Status::OK();
605 }
606
607 Status GetSchemaPayload(const Schema& schema, const IpcWriteOptions& options,
608 const DictionaryFieldMapper& mapper, IpcPayload* out) {
609 out->type = MessageType::SCHEMA;
610 return internal::WriteSchemaMessage(schema, mapper, options, &out->metadata);
611 }
612
613 Status GetDictionaryPayload(int64_t id, const std::shared_ptr<Array>& dictionary,
614 const IpcWriteOptions& options, IpcPayload* out) {
615 return GetDictionaryPayload(id, false, dictionary, options, out);
616 }
617
618 Status GetDictionaryPayload(int64_t id, bool is_delta,
619 const std::shared_ptr<Array>& dictionary,
620 const IpcWriteOptions& options, IpcPayload* out) {
621 out->type = MessageType::DICTIONARY_BATCH;
622 // Frame of reference is 0, see ARROW-384
623 DictionarySerializer assembler(id, is_delta, /*buffer_start_offset=*/0, options, out);
624 return assembler.Assemble(dictionary);
625 }
626
627 Status GetRecordBatchPayload(const RecordBatch& batch, const IpcWriteOptions& options,
628 IpcPayload* out) {
629 out->type = MessageType::RECORD_BATCH;
630 RecordBatchSerializer assembler(/*buffer_start_offset=*/0, options, out);
631 return assembler.Assemble(batch);
632 }
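
// A minimal usage sketch (illustrative only, not part of the library): compute
// a payload once and write it to any OutputStream. `batch` (a RecordBatch) and
// `sink` (an io::OutputStream*) are assumed to exist.
//
//   IpcPayload payload;
//   RETURN_NOT_OK(
//       GetRecordBatchPayload(batch, IpcWriteOptions::Defaults(), &payload));
//   int32_t metadata_length = 0;
//   RETURN_NOT_OK(WriteIpcPayload(payload, IpcWriteOptions::Defaults(), sink,
//                                 &metadata_length));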
633
634 Status WriteRecordBatch(const RecordBatch& batch, int64_t buffer_start_offset,
635 io::OutputStream* dst, int32_t* metadata_length,
636 int64_t* body_length, const IpcWriteOptions& options) {
637 IpcPayload payload;
638 RecordBatchSerializer assembler(buffer_start_offset, options, &payload);
639 RETURN_NOT_OK(assembler.Assemble(batch));
640
641 // TODO: it's a rough edge that the metadata and body length here are
642 // computed separately
643
644 // The body size is computed in the payload
645 *body_length = payload.body_length;
646
647 return WriteIpcPayload(payload, options, dst, metadata_length);
648 }
649
650 Status WriteRecordBatchStream(const std::vector<std::shared_ptr<RecordBatch>>& batches,
651 const IpcWriteOptions& options, io::OutputStream* dst) {
652 ASSIGN_OR_RAISE(std::shared_ptr<RecordBatchWriter> writer,
653 MakeStreamWriter(dst, batches[0]->schema(), options));
654 for (const auto& batch : batches) {
655 DCHECK(batch->schema()->Equals(*batches[0]->schema())) << "Schemas unequal";
656 RETURN_NOT_OK(writer->WriteRecordBatch(*batch));
657 }
658 RETURN_NOT_OK(writer->Close());
659 return Status::OK();
660 }
661
662 namespace {
663
664 Status WriteTensorHeader(const Tensor& tensor, io::OutputStream* dst,
665 int32_t* metadata_length) {
666 IpcWriteOptions options;
667 options.alignment = kTensorAlignment;
668 std::shared_ptr<Buffer> metadata;
669 ARROW_ASSIGN_OR_RAISE(metadata, internal::WriteTensorMessage(tensor, 0, options));
670 return WriteMessage(*metadata, options, dst, metadata_length);
671 }
672
673 Status WriteStridedTensorData(int dim_index, int64_t offset, int elem_size,
674 const Tensor& tensor, uint8_t* scratch_space,
675 io::OutputStream* dst) {
676 if (dim_index == tensor.ndim() - 1) {
677 const uint8_t* data_ptr = tensor.raw_data() + offset;
678 const int64_t stride = tensor.strides()[dim_index];
679 for (int64_t i = 0; i < tensor.shape()[dim_index]; ++i) {
680 memcpy(scratch_space + i * elem_size, data_ptr, elem_size);
681 data_ptr += stride;
682 }
683 return dst->Write(scratch_space, elem_size * tensor.shape()[dim_index]);
684 }
685 for (int64_t i = 0; i < tensor.shape()[dim_index]; ++i) {
686 RETURN_NOT_OK(WriteStridedTensorData(dim_index + 1, offset, elem_size, tensor,
687 scratch_space, dst));
688 offset += tensor.strides()[dim_index];
689 }
690 return Status::OK();
691 }
692
693 Status GetContiguousTensor(const Tensor& tensor, MemoryPool* pool,
694 std::unique_ptr<Tensor>* out) {
695 const int elem_size = GetByteWidth(*tensor.type());
696
697 ARROW_ASSIGN_OR_RAISE(
698 auto scratch_space,
699 AllocateBuffer(tensor.shape()[tensor.ndim() - 1] * elem_size, pool));
700
701 ARROW_ASSIGN_OR_RAISE(std::shared_ptr<ResizableBuffer> contiguous_data,
702 AllocateResizableBuffer(tensor.size() * elem_size, pool));
703
704 io::BufferOutputStream stream(contiguous_data);
705 RETURN_NOT_OK(WriteStridedTensorData(0, 0, elem_size, tensor,
706 scratch_space->mutable_data(), &stream));
707
708 out->reset(new Tensor(tensor.type(), contiguous_data, tensor.shape()));
709
710 return Status::OK();
711 }
712
713 } // namespace
714
715 Status WriteTensor(const Tensor& tensor, io::OutputStream* dst, int32_t* metadata_length,
716 int64_t* body_length) {
717 const int elem_size = GetByteWidth(*tensor.type());
718
719 *body_length = tensor.size() * elem_size;
720
721 // Tensor metadata accounts for padding
722 if (tensor.is_contiguous()) {
723 RETURN_NOT_OK(WriteTensorHeader(tensor, dst, metadata_length));
724 auto data = tensor.data();
725 if (data && data->data()) {
726 RETURN_NOT_OK(dst->Write(data->data(), *body_length));
727 } else {
728 *body_length = 0;
729 }
730 } else {
731 // The tensor written is made contiguous
732 Tensor dummy(tensor.type(), nullptr, tensor.shape());
733 RETURN_NOT_OK(WriteTensorHeader(dummy, dst, metadata_length));
734
735 // TODO: Do we care enough about this temporary allocation to pass in a
736 // MemoryPool to this function?
737 ARROW_ASSIGN_OR_RAISE(auto scratch_space,
738 AllocateBuffer(tensor.shape()[tensor.ndim() - 1] * elem_size));
739
740 RETURN_NOT_OK(WriteStridedTensorData(0, 0, elem_size, tensor,
741 scratch_space->mutable_data(), dst));
742 }
743
744 return Status::OK();
745 }
746
747 Result<std::unique_ptr<Message>> GetTensorMessage(const Tensor& tensor,
748 MemoryPool* pool) {
749 const Tensor* tensor_to_write = &tensor;
750 std::unique_ptr<Tensor> temp_tensor;
751
752 if (!tensor.is_contiguous()) {
753 RETURN_NOT_OK(GetContiguousTensor(tensor, pool, &temp_tensor));
754 tensor_to_write = temp_tensor.get();
755 }
756
757 IpcWriteOptions options;
758 options.alignment = kTensorAlignment;
759 std::shared_ptr<Buffer> metadata;
760 ARROW_ASSIGN_OR_RAISE(metadata,
761 internal::WriteTensorMessage(*tensor_to_write, 0, options));
762 return std::unique_ptr<Message>(new Message(metadata, tensor_to_write->data()));
763 }
764
765 namespace internal {
766
767 class SparseTensorSerializer {
768 public:
769 SparseTensorSerializer(int64_t buffer_start_offset, IpcPayload* out)
770 : out_(out),
771 buffer_start_offset_(buffer_start_offset),
772 options_(IpcWriteOptions::Defaults()) {}
773
774 ~SparseTensorSerializer() = default;
775
776 Status VisitSparseIndex(const SparseIndex& sparse_index) {
777 switch (sparse_index.format_id()) {
778 case SparseTensorFormat::COO:
779 RETURN_NOT_OK(
780 VisitSparseCOOIndex(checked_cast<const SparseCOOIndex&>(sparse_index)));
781 break;
782
783 case SparseTensorFormat::CSR:
784 RETURN_NOT_OK(
785 VisitSparseCSRIndex(checked_cast<const SparseCSRIndex&>(sparse_index)));
786 break;
787
788 case SparseTensorFormat::CSC:
789 RETURN_NOT_OK(
790 VisitSparseCSCIndex(checked_cast<const SparseCSCIndex&>(sparse_index)));
791 break;
792
793 case SparseTensorFormat::CSF:
794 RETURN_NOT_OK(
795 VisitSparseCSFIndex(checked_cast<const SparseCSFIndex&>(sparse_index)));
796 break;
797
798 default:
799 std::stringstream ss;
800 ss << "Unable to convert type: " << sparse_index.ToString() << std::endl;
801 return Status::NotImplemented(ss.str());
802 }
803
804 return Status::OK();
805 }
806
807 Status SerializeMetadata(const SparseTensor& sparse_tensor) {
808 return WriteSparseTensorMessage(sparse_tensor, out_->body_length, buffer_meta_,
809 options_)
810 .Value(&out_->metadata);
811 }
812
813 Status Assemble(const SparseTensor& sparse_tensor) {
814 if (buffer_meta_.size() > 0) {
815 buffer_meta_.clear();
816 out_->body_buffers.clear();
817 }
818
819 RETURN_NOT_OK(VisitSparseIndex(*sparse_tensor.sparse_index()));
820 out_->body_buffers.emplace_back(sparse_tensor.data());
821
822 int64_t offset = buffer_start_offset_;
823 buffer_meta_.reserve(out_->body_buffers.size());
824
825 for (size_t i = 0; i < out_->body_buffers.size(); ++i) {
826 const Buffer* buffer = out_->body_buffers[i].get();
827 int64_t size = buffer->size();
828 int64_t padding = BitUtil::RoundUpToMultipleOf8(size) - size;
829 buffer_meta_.push_back({offset, size + padding});
830 offset += size + padding;
831 }
832
833 out_->body_length = offset - buffer_start_offset_;
834 DCHECK(BitUtil::IsMultipleOf8(out_->body_length));
835
836 return SerializeMetadata(sparse_tensor);
837 }
838
839 private:
840 Status VisitSparseCOOIndex(const SparseCOOIndex& sparse_index) {
841 out_->body_buffers.emplace_back(sparse_index.indices()->data());
842 return Status::OK();
843 }
844
845 Status VisitSparseCSRIndex(const SparseCSRIndex& sparse_index) {
846 out_->body_buffers.emplace_back(sparse_index.indptr()->data());
847 out_->body_buffers.emplace_back(sparse_index.indices()->data());
848 return Status::OK();
849 }
850
851 Status VisitSparseCSCIndex(const SparseCSCIndex& sparse_index) {
852 out_->body_buffers.emplace_back(sparse_index.indptr()->data());
853 out_->body_buffers.emplace_back(sparse_index.indices()->data());
854 return Status::OK();
855 }
856
857 Status VisitSparseCSFIndex(const SparseCSFIndex& sparse_index) {
858 for (const std::shared_ptr<arrow::Tensor>& indptr : sparse_index.indptr()) {
859 out_->body_buffers.emplace_back(indptr->data());
860 }
861 for (const std::shared_ptr<arrow::Tensor>& indices : sparse_index.indices()) {
862 out_->body_buffers.emplace_back(indices->data());
863 }
864 return Status::OK();
865 }
866
867 IpcPayload* out_;
868
869 std::vector<internal::BufferMetadata> buffer_meta_;
870 int64_t buffer_start_offset_;
871 IpcWriteOptions options_;
872 };
873
874 } // namespace internal
875
876 Status WriteSparseTensor(const SparseTensor& sparse_tensor, io::OutputStream* dst,
877 int32_t* metadata_length, int64_t* body_length) {
878 IpcPayload payload;
879 internal::SparseTensorSerializer writer(0, &payload);
880 RETURN_NOT_OK(writer.Assemble(sparse_tensor));
881
882 *body_length = payload.body_length;
883 return WriteIpcPayload(payload, IpcWriteOptions::Defaults(), dst, metadata_length);
884 }
885
886 Status GetSparseTensorPayload(const SparseTensor& sparse_tensor, MemoryPool* pool,
887 IpcPayload* out) {
888 internal::SparseTensorSerializer writer(0, out);
889 return writer.Assemble(sparse_tensor);
890 }
891
892 Result<std::unique_ptr<Message>> GetSparseTensorMessage(const SparseTensor& sparse_tensor,
893 MemoryPool* pool) {
894 IpcPayload payload;
895 RETURN_NOT_OK(GetSparseTensorPayload(sparse_tensor, pool, &payload));
896 return std::unique_ptr<Message>(
897 new Message(std::move(payload.metadata), std::move(payload.body_buffers[0])));
898 }
899
900 int64_t GetPayloadSize(const IpcPayload& payload, const IpcWriteOptions& options) {
901 const int32_t prefix_size = options.write_legacy_ipc_format ? 4 : 8;
902 const int32_t flatbuffer_size = static_cast<int32_t>(payload.metadata->size());
903 const int32_t padded_message_length = static_cast<int32_t>(
904 PaddedLength(flatbuffer_size + prefix_size, options.alignment));
905 // body_length already accounts for padding
906 return payload.body_length + padded_message_length;
907 }
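
// For example: with the current (non-legacy) format the prefix is 8 bytes, so a
// 116-byte metadata flatbuffer gives PaddedLength(116 + 8, 8) == 128, and
// GetPayloadSize() returns body_length + 128 (body_length already includes its
// own padding).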
908
909 Status GetRecordBatchSize(const RecordBatch& batch, int64_t* size) {
910 return GetRecordBatchSize(batch, IpcWriteOptions::Defaults(), size);
911 }
912
913 Status GetRecordBatchSize(const RecordBatch& batch, const IpcWriteOptions& options,
914 int64_t* size) {
915 // emulates the behavior of Write without actually writing
916 int32_t metadata_length = 0;
917 int64_t body_length = 0;
918 io::MockOutputStream dst;
919 RETURN_NOT_OK(
920 WriteRecordBatch(batch, 0, &dst, &metadata_length, &body_length, options));
921 *size = dst.GetExtentBytesWritten();
922 return Status::OK();
923 }
924
925 Status GetTensorSize(const Tensor& tensor, int64_t* size) {
926 // emulates the behavior of Write without actually writing
927 int32_t metadata_length = 0;
928 int64_t body_length = 0;
929 io::MockOutputStream dst;
930 RETURN_NOT_OK(WriteTensor(tensor, &dst, &metadata_length, &body_length));
931 *size = dst.GetExtentBytesWritten();
932 return Status::OK();
933 }
934
935 // ----------------------------------------------------------------------
936
937 RecordBatchWriter::~RecordBatchWriter() {}
938
939 Status RecordBatchWriter::WriteTable(const Table& table, int64_t max_chunksize) {
940 TableBatchReader reader(table);
941
942 if (max_chunksize > 0) {
943 reader.set_chunksize(max_chunksize);
944 }
945
946 std::shared_ptr<RecordBatch> batch;
947 while (true) {
948 RETURN_NOT_OK(reader.ReadNext(&batch));
949 if (batch == nullptr) {
950 break;
951 }
952 RETURN_NOT_OK(WriteRecordBatch(*batch));
953 }
954
955 return Status::OK();
956 }
957
958 Status RecordBatchWriter::WriteTable(const Table& table) { return WriteTable(table, -1); }
959
960 // ----------------------------------------------------------------------
961 // Payload writer implementation
962
963 namespace internal {
964
965 IpcPayloadWriter::~IpcPayloadWriter() {}
966
967 Status IpcPayloadWriter::Start() { return Status::OK(); }
968
969 class ARROW_EXPORT IpcFormatWriter : public RecordBatchWriter {
970 public:
971   // A RecordBatchWriter implementation that writes to an IpcPayloadWriter.
972 IpcFormatWriter(std::unique_ptr<internal::IpcPayloadWriter> payload_writer,
973 const Schema& schema, const IpcWriteOptions& options,
974 bool is_file_format)
975 : payload_writer_(std::move(payload_writer)),
976 schema_(schema),
977 mapper_(schema),
978 is_file_format_(is_file_format),
979 options_(options) {}
980
981 // A Schema-owning constructor variant
982 IpcFormatWriter(std::unique_ptr<internal::IpcPayloadWriter> payload_writer,
983 const std::shared_ptr<Schema>& schema, const IpcWriteOptions& options,
984 bool is_file_format)
985 : IpcFormatWriter(std::move(payload_writer), *schema, options, is_file_format) {
986 shared_schema_ = schema;
987 }
988
989 Status WriteRecordBatch(const RecordBatch& batch) override {
990 if (!batch.schema()->Equals(schema_, false /* check_metadata */)) {
991 return Status::Invalid("Tried to write record batch with different schema");
992 }
993
994 RETURN_NOT_OK(CheckStarted());
995
996 RETURN_NOT_OK(WriteDictionaries(batch));
997
998 IpcPayload payload;
999 RETURN_NOT_OK(GetRecordBatchPayload(batch, options_, &payload));
1000 RETURN_NOT_OK(WritePayload(payload));
1001 ++stats_.num_record_batches;
1002 return Status::OK();
1003 }
1004
1005 Status WriteTable(const Table& table, int64_t max_chunksize) override {
1006 if (is_file_format_ && options_.unify_dictionaries) {
1007 ARROW_ASSIGN_OR_RAISE(auto unified_table,
1008 DictionaryUnifier::UnifyTable(table, options_.memory_pool));
1009 return RecordBatchWriter::WriteTable(*unified_table, max_chunksize);
1010 } else {
1011 return RecordBatchWriter::WriteTable(table, max_chunksize);
1012 }
1013 }
1014
1015 Status Close() override {
1016 RETURN_NOT_OK(CheckStarted());
1017 return payload_writer_->Close();
1018 }
1019
1020 Status Start() {
1021 started_ = true;
1022 RETURN_NOT_OK(payload_writer_->Start());
1023
1024 IpcPayload payload;
1025 RETURN_NOT_OK(GetSchemaPayload(schema_, options_, mapper_, &payload));
1026 return WritePayload(payload);
1027 }
1028
1029 WriteStats stats() const override { return stats_; }
1030
1031 protected:
1032 Status CheckStarted() {
1033 if (!started_) {
1034 return Start();
1035 }
1036 return Status::OK();
1037 }
1038
1039 Status WriteDictionaries(const RecordBatch& batch) {
1040 ARROW_ASSIGN_OR_RAISE(const auto dictionaries, CollectDictionaries(batch, mapper_));
1041 const auto equal_options = EqualOptions().nans_equal(true);
1042
1043 for (const auto& pair : dictionaries) {
1044 int64_t dictionary_id = pair.first;
1045 const auto& dictionary = pair.second;
1046
1047 // If a dictionary with this id was already emitted, check if it was the same.
1048 auto* last_dictionary = &last_dictionaries_[dictionary_id];
1049 const bool dictionary_exists = (*last_dictionary != nullptr);
1050 int64_t delta_start = 0;
1051 if (dictionary_exists) {
1052 if ((*last_dictionary)->data() == dictionary->data()) {
1053 // Fast shortcut for a common case.
1054 // Same dictionary data by pointer => no need to emit it again
1055 continue;
1056 }
1057 const int64_t last_length = (*last_dictionary)->length();
1058 const int64_t new_length = dictionary->length();
1059 if (new_length == last_length &&
1060 ((*last_dictionary)->Equals(dictionary, equal_options))) {
1061 // Same dictionary by value => no need to emit it again
1062 // (while this can have a CPU cost, this code path is required
1063 // for the IPC file format)
1064 continue;
1065 }
1066 if (is_file_format_) {
1067 return Status::Invalid(
1068 "Dictionary replacement detected when writing IPC file format. "
1069 "Arrow IPC files only support a single dictionary for a given field "
1070 "across all batches.");
1071 }
1072
1073 // (the read path doesn't support outer dictionary deltas, don't emit them)
1074 if (new_length > last_length && options_.emit_dictionary_deltas &&
1075 !HasNestedDict(*dictionary->data()) &&
1076 ((*last_dictionary)
1077 ->RangeEquals(dictionary, 0, last_length, 0, equal_options))) {
1078 // New dictionary starts with the current dictionary
1079 delta_start = last_length;
1080 }
1081 }
1082
1083 IpcPayload payload;
1084 if (delta_start) {
1085 RETURN_NOT_OK(GetDictionaryPayload(dictionary_id, /*is_delta=*/true,
1086 dictionary->Slice(delta_start), options_,
1087 &payload));
1088 } else {
1089 RETURN_NOT_OK(
1090 GetDictionaryPayload(dictionary_id, dictionary, options_, &payload));
1091 }
1092 RETURN_NOT_OK(WritePayload(payload));
1093 ++stats_.num_dictionary_batches;
1094 if (dictionary_exists) {
1095 if (delta_start) {
1096 ++stats_.num_dictionary_deltas;
1097 } else {
1098 ++stats_.num_replaced_dictionaries;
1099 }
1100 }
1101
1102 // Remember dictionary for next batches
1103 *last_dictionary = dictionary;
1104 }
1105 return Status::OK();
1106 }
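
  // For example (a sketch of the delta path above, for the stream format with
  // emit_dictionary_deltas enabled): if one batch used the dictionary
  // ["a", "b"] and a later batch uses ["a", "b", "c"], only the suffix ["c"]
  // is re-emitted as a delta dictionary batch (is_delta = true). Replacing
  // ["a", "b"] with ["x", "y"] instead emits a full replacement, which the
  // file format rejects with an error.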
1107
1108 Status WritePayload(const IpcPayload& payload) {
1109 RETURN_NOT_OK(payload_writer_->WritePayload(payload));
1110 ++stats_.num_messages;
1111 return Status::OK();
1112 }
1113
1114 std::unique_ptr<IpcPayloadWriter> payload_writer_;
1115 std::shared_ptr<Schema> shared_schema_;
1116 const Schema& schema_;
1117 const DictionaryFieldMapper mapper_;
1118 const bool is_file_format_;
1119
1120 // A map of last-written dictionaries by id.
1121   // This is required to avoid emitting the same dictionary again and again,
1122 // and also for correctness when writing the IPC file format
1123 // (where replacements and deltas are unsupported).
1124 // The latter is also why we can't use weak_ptr.
1125 std::unordered_map<int64_t, std::shared_ptr<Array>> last_dictionaries_;
1126
1127 bool started_ = false;
1128 IpcWriteOptions options_;
1129 WriteStats stats_;
1130 };
1131
1132 class StreamBookKeeper {
1133 public:
1134 StreamBookKeeper(const IpcWriteOptions& options, io::OutputStream* sink)
1135 : options_(options), sink_(sink), position_(-1) {}
1136 StreamBookKeeper(const IpcWriteOptions& options, std::shared_ptr<io::OutputStream> sink)
1137 : options_(options),
1138 sink_(sink.get()),
1139 owned_sink_(std::move(sink)),
1140 position_(-1) {}
1141
1142 Status UpdatePosition() { return sink_->Tell().Value(&position_); }
1143
1144 Status UpdatePositionCheckAligned() {
1145 RETURN_NOT_OK(UpdatePosition());
1146 DCHECK_EQ(0, position_ % 8) << "Stream is not aligned";
1147 return Status::OK();
1148 }
1149
1150 Status Align(int32_t alignment = kArrowIpcAlignment) {
1151 // Adds padding bytes if necessary to ensure all memory blocks are written on
1152 // 8-byte (or other alignment) boundaries.
1153 int64_t remainder = PaddedLength(position_, alignment) - position_;
1154 if (remainder > 0) {
1155 return Write(kPaddingBytes, remainder);
1156 }
1157 return Status::OK();
1158 }
1159
1160 // Write data and update position
1161 Status Write(const void* data, int64_t nbytes) {
1162 RETURN_NOT_OK(sink_->Write(data, nbytes));
1163 position_ += nbytes;
1164 return Status::OK();
1165 }
1166
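  // In the current format the end-of-stream marker is 8 bytes: the 0xFFFFFFFF
  // continuation token (kIpcContinuationToken) followed by a zero message
  // length. With write_legacy_ipc_format, only the 4-byte zero length is
  // written.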
1167 Status WriteEOS() {
1168 // End of stream marker
1169 constexpr int32_t kZeroLength = 0;
1170 if (!options_.write_legacy_ipc_format) {
1171 RETURN_NOT_OK(Write(&kIpcContinuationToken, sizeof(int32_t)));
1172 }
1173 return Write(&kZeroLength, sizeof(int32_t));
1174 }
1175
1176 protected:
1177 IpcWriteOptions options_;
1178 io::OutputStream* sink_;
1179 std::shared_ptr<io::OutputStream> owned_sink_;
1180 int64_t position_;
1181 };
1182
1183 /// An IpcPayloadWriter implementation that writes to an IPC stream
1184 /// (with an end-of-stream marker)
1185 class PayloadStreamWriter : public IpcPayloadWriter, protected StreamBookKeeper {
1186 public:
1187 PayloadStreamWriter(io::OutputStream* sink,
1188 const IpcWriteOptions& options = IpcWriteOptions::Defaults())
1189 : StreamBookKeeper(options, sink) {}
1190 PayloadStreamWriter(std::shared_ptr<io::OutputStream> sink,
1191 const IpcWriteOptions& options = IpcWriteOptions::Defaults())
1192 : StreamBookKeeper(options, std::move(sink)) {}
1193
1194 ~PayloadStreamWriter() override = default;
1195
1196 Status WritePayload(const IpcPayload& payload) override {
1197 #ifndef NDEBUG
1198 // Catch bug fixed in ARROW-3236
1199 RETURN_NOT_OK(UpdatePositionCheckAligned());
1200 #endif
1201
1202 int32_t metadata_length = 0; // unused
1203 RETURN_NOT_OK(WriteIpcPayload(payload, options_, sink_, &metadata_length));
1204 RETURN_NOT_OK(UpdatePositionCheckAligned());
1205 return Status::OK();
1206 }
1207
1208 Status Close() override { return WriteEOS(); }
1209 };
1210
1211 /// An IpcPayloadWriter implementation that writes to an IPC file
1212 /// (with a footer as defined in File.fbs)
1213 class PayloadFileWriter : public internal::IpcPayloadWriter, protected StreamBookKeeper {
1214 public:
1215 PayloadFileWriter(const IpcWriteOptions& options, const std::shared_ptr<Schema>& schema,
1216 const std::shared_ptr<const KeyValueMetadata>& metadata,
1217 io::OutputStream* sink)
1218 : StreamBookKeeper(options, sink), schema_(schema), metadata_(metadata) {}
1219 PayloadFileWriter(const IpcWriteOptions& options, const std::shared_ptr<Schema>& schema,
1220 const std::shared_ptr<const KeyValueMetadata>& metadata,
1221 std::shared_ptr<io::OutputStream> sink)
1222 : StreamBookKeeper(options, std::move(sink)),
1223 schema_(schema),
1224 metadata_(metadata) {}
1225
1226 ~PayloadFileWriter() override = default;
1227
1228 Status WritePayload(const IpcPayload& payload) override {
1229 #ifndef NDEBUG
1230 // Catch bug fixed in ARROW-3236
1231 RETURN_NOT_OK(UpdatePositionCheckAligned());
1232 #endif
1233
1234     // Metadata length must include padding; it is computed by WriteIpcPayload()
1235 FileBlock block = {position_, 0, payload.body_length};
1236 RETURN_NOT_OK(WriteIpcPayload(payload, options_, sink_, &block.metadata_length));
1237 RETURN_NOT_OK(UpdatePositionCheckAligned());
1238
1239 // Record position and size of some message types, to list them in the footer
1240 switch (payload.type) {
1241 case MessageType::DICTIONARY_BATCH:
1242 dictionaries_.push_back(block);
1243 break;
1244 case MessageType::RECORD_BATCH:
1245 record_batches_.push_back(block);
1246 break;
1247 default:
1248 break;
1249 }
1250
1251 return Status::OK();
1252 }
1253
1254 Status Start() override {
1255 // ARROW-3236: The initial position -1 needs to be updated to the stream's
1256     // current position; otherwise, an incorrect amount of padding will be
1257 // written to new files.
1258 RETURN_NOT_OK(UpdatePosition());
1259
1260 // It is only necessary to align to 8-byte boundary at the start of the file
1261 RETURN_NOT_OK(Write(kArrowMagicBytes, strlen(kArrowMagicBytes)));
1262 RETURN_NOT_OK(Align());
1263
1264 return Status::OK();
1265 }
1266
1267 Status Close() override {
1268 // Write 0 EOS message for compatibility with sequential readers
1269 RETURN_NOT_OK(WriteEOS());
1270
1271 // Write file footer
1272 RETURN_NOT_OK(UpdatePosition());
1273 int64_t initial_position = position_;
1274 RETURN_NOT_OK(
1275 WriteFileFooter(*schema_, dictionaries_, record_batches_, metadata_, sink_));
1276
1277 // Write footer length
1278 RETURN_NOT_OK(UpdatePosition());
1279 int32_t footer_length = static_cast<int32_t>(position_ - initial_position);
1280 if (footer_length <= 0) {
1281 return Status::Invalid("Invalid file footer");
1282 }
1283
1284 // write footer length in little endian
1285 footer_length = BitUtil::ToLittleEndian(footer_length);
1286 RETURN_NOT_OK(Write(&footer_length, sizeof(int32_t)));
1287
1288 // Write magic bytes to end file
1289 return Write(kArrowMagicBytes, strlen(kArrowMagicBytes));
1290 }
1291
1292 protected:
1293 std::shared_ptr<Schema> schema_;
1294 std::shared_ptr<const KeyValueMetadata> metadata_;
1295 std::vector<FileBlock> dictionaries_;
1296 std::vector<FileBlock> record_batches_;
1297 };
1298
1299 } // namespace internal
1300
1301 Result<std::shared_ptr<RecordBatchWriter>> MakeStreamWriter(
1302 io::OutputStream* sink, const std::shared_ptr<Schema>& schema,
1303 const IpcWriteOptions& options) {
1304 return std::make_shared<internal::IpcFormatWriter>(
1305 ::arrow::internal::make_unique<internal::PayloadStreamWriter>(sink, options),
1306 schema, options, /*is_file_format=*/false);
1307 }
1308
1309 Result<std::shared_ptr<RecordBatchWriter>> MakeStreamWriter(
1310 std::shared_ptr<io::OutputStream> sink, const std::shared_ptr<Schema>& schema,
1311 const IpcWriteOptions& options) {
1312 return std::make_shared<internal::IpcFormatWriter>(
1313 ::arrow::internal::make_unique<internal::PayloadStreamWriter>(std::move(sink),
1314 options),
1315 schema, options, /*is_file_format=*/false);
1316 }
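
// A minimal usage sketch (illustrative only; `schema` and `batch` are assumed
// to exist and to agree):
//
//   ARROW_ASSIGN_OR_RAISE(auto sink, io::BufferOutputStream::Create());
//   ARROW_ASSIGN_OR_RAISE(auto writer,
//                         MakeStreamWriter(sink, schema, IpcWriteOptions::Defaults()));
//   RETURN_NOT_OK(writer->WriteRecordBatch(*batch));
//   RETURN_NOT_OK(writer->Close());
//   ARROW_ASSIGN_OR_RAISE(auto stream_buffer, sink->Finish());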
1317
1318 Result<std::shared_ptr<RecordBatchWriter>> NewStreamWriter(
1319 io::OutputStream* sink, const std::shared_ptr<Schema>& schema,
1320 const IpcWriteOptions& options) {
1321 return MakeStreamWriter(sink, schema, options);
1322 }
1323
1324 Result<std::shared_ptr<RecordBatchWriter>> MakeFileWriter(
1325 io::OutputStream* sink, const std::shared_ptr<Schema>& schema,
1326 const IpcWriteOptions& options,
1327 const std::shared_ptr<const KeyValueMetadata>& metadata) {
1328 return std::make_shared<internal::IpcFormatWriter>(
1329 ::arrow::internal::make_unique<internal::PayloadFileWriter>(options, schema,
1330 metadata, sink),
1331 schema, options, /*is_file_format=*/true);
1332 }
1333
1334 Result<std::shared_ptr<RecordBatchWriter>> MakeFileWriter(
1335 std::shared_ptr<io::OutputStream> sink, const std::shared_ptr<Schema>& schema,
1336 const IpcWriteOptions& options,
1337 const std::shared_ptr<const KeyValueMetadata>& metadata) {
1338 return std::make_shared<internal::IpcFormatWriter>(
1339 ::arrow::internal::make_unique<internal::PayloadFileWriter>(
1340 options, schema, metadata, std::move(sink)),
1341 schema, options, /*is_file_format=*/true);
1342 }
1343
1344 Result<std::shared_ptr<RecordBatchWriter>> NewFileWriter(
1345 io::OutputStream* sink, const std::shared_ptr<Schema>& schema,
1346 const IpcWriteOptions& options,
1347 const std::shared_ptr<const KeyValueMetadata>& metadata) {
1348 return MakeFileWriter(sink, schema, options, metadata);
1349 }
1350
1351 namespace internal {
1352
1353 Result<std::unique_ptr<RecordBatchWriter>> OpenRecordBatchWriter(
1354 std::unique_ptr<IpcPayloadWriter> sink, const std::shared_ptr<Schema>& schema,
1355 const IpcWriteOptions& options) {
1356 // XXX should we call Start()?
1357 return ::arrow::internal::make_unique<internal::IpcFormatWriter>(
1358 std::move(sink), schema, options, /*is_file_format=*/false);
1359 }
1360
1361 Result<std::unique_ptr<IpcPayloadWriter>> MakePayloadStreamWriter(
1362 io::OutputStream* sink, const IpcWriteOptions& options) {
1363 return ::arrow::internal::make_unique<internal::PayloadStreamWriter>(sink, options);
1364 }
1365
1366 Result<std::unique_ptr<IpcPayloadWriter>> MakePayloadFileWriter(
1367 io::OutputStream* sink, const std::shared_ptr<Schema>& schema,
1368 const IpcWriteOptions& options,
1369 const std::shared_ptr<const KeyValueMetadata>& metadata) {
1370 return ::arrow::internal::make_unique<internal::PayloadFileWriter>(options, schema,
1371 metadata, sink);
1372 }
1373
1374 } // namespace internal
1375
1376 // ----------------------------------------------------------------------
1377 // Serialization public APIs
1378
1379 Result<std::shared_ptr<Buffer>> SerializeRecordBatch(const RecordBatch& batch,
1380 std::shared_ptr<MemoryManager> mm) {
1381 auto options = IpcWriteOptions::Defaults();
1382 int64_t size = 0;
1383 RETURN_NOT_OK(GetRecordBatchSize(batch, options, &size));
1384 ARROW_ASSIGN_OR_RAISE(auto buffer, mm->AllocateBuffer(size));
1385 ARROW_ASSIGN_OR_RAISE(auto writer, Buffer::GetWriter(buffer));
1386
1387 // XXX Should we have a helper function for getting a MemoryPool
1388 // for any MemoryManager (not only CPU)?
1389 if (mm->is_cpu()) {
1390 options.memory_pool = checked_pointer_cast<CPUMemoryManager>(mm)->pool();
1391 }
1392 RETURN_NOT_OK(SerializeRecordBatch(batch, options, writer.get()));
1393 RETURN_NOT_OK(writer->Close());
1394 return buffer;
1395 }
1396
1397 Result<std::shared_ptr<Buffer>> SerializeRecordBatch(const RecordBatch& batch,
1398 const IpcWriteOptions& options) {
1399 int64_t size = 0;
1400 RETURN_NOT_OK(GetRecordBatchSize(batch, options, &size));
1401 ARROW_ASSIGN_OR_RAISE(std::shared_ptr<Buffer> buffer,
1402 AllocateBuffer(size, options.memory_pool));
1403
1404 io::FixedSizeBufferWriter stream(buffer);
1405 RETURN_NOT_OK(SerializeRecordBatch(batch, options, &stream));
1406 return buffer;
1407 }
1408
1409 Status SerializeRecordBatch(const RecordBatch& batch, const IpcWriteOptions& options,
1410 io::OutputStream* out) {
1411 int32_t metadata_length = 0;
1412 int64_t body_length = 0;
1413 return WriteRecordBatch(batch, 0, out, &metadata_length, &body_length, options);
1414 }
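
// A minimal usage sketch of the buffer-returning overload above (illustrative
// only; `batch` is assumed to exist):
//
//   ARROW_ASSIGN_OR_RAISE(
//       std::shared_ptr<Buffer> serialized,
//       SerializeRecordBatch(*batch, IpcWriteOptions::Defaults()));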
1415
1416 Result<std::shared_ptr<Buffer>> SerializeSchema(const Schema& schema, MemoryPool* pool) {
1417 ARROW_ASSIGN_OR_RAISE(auto stream, io::BufferOutputStream::Create(1024, pool));
1418
1419 auto options = IpcWriteOptions::Defaults();
1420   const bool is_file_format = false;  // irrelevant, as we don't write dictionaries
1421 internal::IpcFormatWriter writer(
1422 ::arrow::internal::make_unique<internal::PayloadStreamWriter>(stream.get()), schema,
1423 options, is_file_format);
1424 RETURN_NOT_OK(writer.Start());
1425 return stream->Finish();
1426 }
1427
1428 } // namespace ipc
1429 } // namespace arrow