// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
#include "arrow/ipc/writer.h"

#include <algorithm>
#include <cstdint>
#include <cstring>
#include <limits>
#include <sstream>
#include <string>
#include <type_traits>
#include <unordered_map>
#include <utility>
#include <vector>

#include "arrow/array.h"
#include "arrow/buffer.h"
#include "arrow/device.h"
#include "arrow/extension_type.h"
#include "arrow/io/interfaces.h"
#include "arrow/io/memory.h"
#include "arrow/ipc/dictionary.h"
#include "arrow/ipc/message.h"
#include "arrow/ipc/metadata_internal.h"
#include "arrow/ipc/util.h"
#include "arrow/record_batch.h"
#include "arrow/result_internal.h"
#include "arrow/sparse_tensor.h"
#include "arrow/status.h"
#include "arrow/table.h"
#include "arrow/type.h"
#include "arrow/type_traits.h"
#include "arrow/util/bit_util.h"
#include "arrow/util/bitmap_ops.h"
#include "arrow/util/checked_cast.h"
#include "arrow/util/compression.h"
#include "arrow/util/endian.h"
#include "arrow/util/key_value_metadata.h"
#include "arrow/util/logging.h"
#include "arrow/util/make_unique.h"
#include "arrow/util/parallel.h"
#include "arrow/visitor_inline.h"
namespace arrow {

using internal::checked_cast;
using internal::checked_pointer_cast;
using internal::CopyBitmap;
using internal::GetByteWidth;

namespace ipc {

using internal::FileBlock;
using internal::kArrowMagicBytes;
bool HasNestedDict(const ArrayData& data) {
  if (data.type->id() == Type::DICTIONARY) {
    return true;
  }
  for (const auto& child : data.child_data) {
    if (HasNestedDict(*child)) {
      return true;
    }
  }
  return false;
}
Status GetTruncatedBitmap(int64_t offset, int64_t length,
                          const std::shared_ptr<Buffer> input, MemoryPool* pool,
                          std::shared_ptr<Buffer>* buffer) {
  if (!input) {
    *buffer = input;
    return Status::OK();
  }
  int64_t min_length = PaddedLength(BitUtil::BytesForBits(length));
  if (offset != 0 || min_length < input->size()) {
    // With a sliced array / non-zero offset, we must copy the bitmap
    ARROW_ASSIGN_OR_RAISE(*buffer, CopyBitmap(pool, input->data(), offset, length));
  } else {
    *buffer = input;
  }
  return Status::OK();
}
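// Example: an array sliced to offset=3, length=10 gets its validity bits
// copied into a fresh zero-based bitmap of PaddedLength(BytesForBits(10)) = 8
// bytes, so the IPC body never carries the unused leading bits of the parent
// buffer.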
Status GetTruncatedBuffer(int64_t offset, int64_t length, int32_t byte_width,
                          const std::shared_ptr<Buffer> input, MemoryPool* pool,
                          std::shared_ptr<Buffer>* buffer) {
  if (!input) {
    *buffer = input;
    return Status::OK();
  }
  int64_t padded_length = PaddedLength(length * byte_width);
  if (offset != 0 || padded_length < input->size()) {
    *buffer =
        SliceBuffer(input, offset * byte_width, std::min(padded_length, input->size()));
  } else {
    *buffer = input;
  }
  return Status::OK();
}
static inline bool NeedTruncate(int64_t offset, const Buffer* buffer,
                                int64_t min_length) {
  // buffer can be NULL
  if (buffer == nullptr) {
    return false;
  }
  return offset != 0 || min_length < buffer->size();
}
class RecordBatchSerializer {
 public:
  RecordBatchSerializer(int64_t buffer_start_offset, const IpcWriteOptions& options,
                        IpcPayload* out)
      : out_(out),
        options_(options),
        max_recursion_depth_(options.max_recursion_depth),
        buffer_start_offset_(buffer_start_offset) {
    DCHECK_GT(max_recursion_depth_, 0);
  }

  virtual ~RecordBatchSerializer() = default;
  Status VisitArray(const Array& arr) {
    static std::shared_ptr<Buffer> kNullBuffer = std::make_shared<Buffer>(nullptr, 0);

    if (max_recursion_depth_ <= 0) {
      return Status::Invalid("Max recursion depth reached");
    }

    if (!options_.allow_64bit && arr.length() > std::numeric_limits<int32_t>::max()) {
      return Status::CapacityError("Cannot write arrays larger than 2^31 - 1 in length");
    }

    // push back all common elements
    field_nodes_.push_back({arr.length(), arr.null_count(), 0});

    // In V4, null types have no validity bitmap
    // In V5 and later, null and union types have no validity bitmap
    if (internal::HasValidityBitmap(arr.type_id(), options_.metadata_version)) {
      if (arr.null_count() > 0) {
        std::shared_ptr<Buffer> bitmap;
        RETURN_NOT_OK(GetTruncatedBitmap(arr.offset(), arr.length(), arr.null_bitmap(),
                                         options_.memory_pool, &bitmap));
        out_->body_buffers.emplace_back(bitmap);
      } else {
        // Push a dummy zero-length buffer, not to be copied
        out_->body_buffers.emplace_back(kNullBuffer);
      }
    }
    return VisitType(arr);
  }
  // Override this for writing dictionary metadata
  virtual Status SerializeMetadata(int64_t num_rows) {
    return WriteRecordBatchMessage(num_rows, out_->body_length, custom_metadata_,
                                   field_nodes_, buffer_meta_, options_, &out_->metadata);
  }
  void AppendCustomMetadata(const std::string& key, const std::string& value) {
    if (!custom_metadata_) {
      custom_metadata_ = std::make_shared<KeyValueMetadata>();
    }
    custom_metadata_->Append(key, value);
  }
  Status CompressBuffer(const Buffer& buffer, util::Codec* codec,
                        std::shared_ptr<Buffer>* out) {
    // Convert buffer to uncompressed-length-prefixed compressed buffer
    int64_t maximum_length = codec->MaxCompressedLen(buffer.size(), buffer.data());
    ARROW_ASSIGN_OR_RAISE(auto result, AllocateBuffer(maximum_length + sizeof(int64_t)));

    int64_t actual_length;
    ARROW_ASSIGN_OR_RAISE(actual_length,
                          codec->Compress(buffer.size(), buffer.data(), maximum_length,
                                          result->mutable_data() + sizeof(int64_t)));
    *reinterpret_cast<int64_t*>(result->mutable_data()) =
        BitUtil::ToLittleEndian(buffer.size());
    *out = SliceBuffer(std::move(result), /*offset=*/0, actual_length + sizeof(int64_t));
    return Status::OK();
  }
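  // The resulting buffer layout is:
  //
  //   [int64 little-endian uncompressed length][compressed body...]
  //
  // e.g. compressing a 1000-byte buffer down to 200 bytes produces a 208-byte
  // output: an 8-byte prefix holding the value 1000, then the 200 compressed
  // bytes.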
  Status CompressBodyBuffers() {
    RETURN_NOT_OK(
        internal::CheckCompressionSupported(options_.codec->compression_type()));

    auto CompressOne = [&](size_t i) {
      if (out_->body_buffers[i]->size() > 0) {
        RETURN_NOT_OK(CompressBuffer(*out_->body_buffers[i], options_.codec.get(),
                                     &out_->body_buffers[i]));
      }
      return Status::OK();
    };

    return ::arrow::internal::OptionalParallelFor(
        options_.use_threads, static_cast<int>(out_->body_buffers.size()), CompressOne);
  }
  Status Assemble(const RecordBatch& batch) {
    if (field_nodes_.size() > 0) {
      field_nodes_.clear();
      buffer_meta_.clear();
      out_->body_buffers.clear();
    }

    // Perform depth-first traversal of the row-batch
    for (int i = 0; i < batch.num_columns(); ++i) {
      RETURN_NOT_OK(VisitArray(*batch.column(i)));
    }

    if (options_.codec != nullptr) {
      RETURN_NOT_OK(CompressBodyBuffers());
    }

    // The position for the start of a buffer relative to the passed frame of
    // reference. May be 0 or some other position in an address space
    int64_t offset = buffer_start_offset_;

    buffer_meta_.reserve(out_->body_buffers.size());

    // Construct the buffer metadata for the record batch header
    for (const auto& buffer : out_->body_buffers) {
      int64_t size = 0;
      int64_t padding = 0;

      // The buffer might be null if we are handling zero row lengths.
      if (buffer) {
        size = buffer->size();
        padding = BitUtil::RoundUpToMultipleOf8(size) - size;
      }

      buffer_meta_.push_back({offset, size});
      offset += size + padding;
    }

    out_->body_length = offset - buffer_start_offset_;
    DCHECK(BitUtil::IsMultipleOf8(out_->body_length));

    // Now that we have computed the locations of all of the buffers in shared
    // memory, the data header can be converted to a flatbuffer and written out

    // Note: The memory written here is prefixed by the size of the flatbuffer
    // itself as an int32_t.
    return SerializeMetadata(batch.num_rows());
  }
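  // Example of the padding arithmetic above: a 13-byte buffer is recorded
  // with size 13 at the current offset, and the next buffer starts 16 bytes
  // later (13 rounded up to a multiple of 8), so every buffer in the body
  // begins on an 8-byte boundary.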
  template <typename ArrayType>
  Status GetZeroBasedValueOffsets(const ArrayType& array,
                                  std::shared_ptr<Buffer>* value_offsets) {
    // Share slicing logic between ListArray, BinaryArray and LargeBinaryArray
    using offset_type = typename ArrayType::offset_type;

    auto offsets = array.value_offsets();

    int64_t required_bytes = sizeof(offset_type) * (array.length() + 1);
    if (array.offset() != 0) {
      // If we have a non-zero offset, then the value offsets do not start at
      // zero. We must a) create a new offsets array with shifted offsets and
      // b) slice the values array accordingly

      ARROW_ASSIGN_OR_RAISE(auto shifted_offsets,
                            AllocateBuffer(required_bytes, options_.memory_pool));

      offset_type* dest_offsets =
          reinterpret_cast<offset_type*>(shifted_offsets->mutable_data());
      const offset_type start_offset = array.value_offset(0);

      for (int i = 0; i < array.length(); ++i) {
        dest_offsets[i] = array.value_offset(i) - start_offset;
      }
      dest_offsets[array.length()] = array.value_offset(array.length()) - start_offset;
      offsets = std::move(shifted_offsets);
    } else {
      // ARROW-6046: Slice offsets to used extent, in case we have a truncated
      // slice
      if (offsets != nullptr && offsets->size() > required_bytes) {
        offsets = SliceBuffer(offsets, 0, required_bytes);
      }
    }
    *value_offsets = std::move(offsets);
    return Status::OK();
  }
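  // Example: a binary array with raw offsets [0, 3, 5, 7, 10] sliced to
  // offset=2, length=2 has value offsets [5, 7, 10]; rebasing by
  // start_offset=5 writes [0, 2, 5], so the receiver can treat the (also
  // sliced) value buffer as starting at zero.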
  Status Visit(const BooleanArray& array) {
    std::shared_ptr<Buffer> data;
    RETURN_NOT_OK(GetTruncatedBitmap(array.offset(), array.length(), array.values(),
                                     options_.memory_pool, &data));
    out_->body_buffers.emplace_back(data);
    return Status::OK();
  }

  Status Visit(const NullArray& array) { return Status::OK(); }
  template <typename T>
  typename std::enable_if<is_number_type<typename T::TypeClass>::value ||
                              is_temporal_type<typename T::TypeClass>::value ||
                              is_fixed_size_binary_type<typename T::TypeClass>::value,
                          Status>::type
  Visit(const T& array) {
    std::shared_ptr<Buffer> data = array.values();

    const int64_t type_width = GetByteWidth(*array.type());
    int64_t min_length = PaddedLength(array.length() * type_width);

    if (NeedTruncate(array.offset(), data.get(), min_length)) {
      // Non-zero offset, slice the buffer
      const int64_t byte_offset = array.offset() * type_width;

      // Send padding if it's available
      const int64_t buffer_length =
          std::min(BitUtil::RoundUpToMultipleOf8(array.length() * type_width),
                   data->size() - byte_offset);
      data = SliceBuffer(data, byte_offset, buffer_length);
    }
    out_->body_buffers.emplace_back(data);
    return Status::OK();
  }
  template <typename T>
  enable_if_base_binary<typename T::TypeClass, Status> Visit(const T& array) {
    std::shared_ptr<Buffer> value_offsets;
    RETURN_NOT_OK(GetZeroBasedValueOffsets<T>(array, &value_offsets));
    auto data = array.value_data();

    int64_t total_data_bytes = 0;
    if (value_offsets) {
      total_data_bytes = array.value_offset(array.length()) - array.value_offset(0);
    }
    if (NeedTruncate(array.offset(), data.get(), total_data_bytes)) {
      // Slice the data buffer to include only the range we need now
      const int64_t start_offset = array.value_offset(0);
      const int64_t slice_length =
          std::min(PaddedLength(total_data_bytes), data->size() - start_offset);
      data = SliceBuffer(data, start_offset, slice_length);
    }

    out_->body_buffers.emplace_back(value_offsets);
    out_->body_buffers.emplace_back(data);
    return Status::OK();
  }
  template <typename T>
  enable_if_base_list<typename T::TypeClass, Status> Visit(const T& array) {
    using offset_type = typename T::offset_type;

    std::shared_ptr<Buffer> value_offsets;
    RETURN_NOT_OK(GetZeroBasedValueOffsets<T>(array, &value_offsets));
    out_->body_buffers.emplace_back(value_offsets);

    --max_recursion_depth_;
    std::shared_ptr<Array> values = array.values();

    offset_type values_offset = 0;
    offset_type values_length = 0;
    if (value_offsets) {
      values_offset = array.value_offset(0);
      values_length = array.value_offset(array.length()) - values_offset;
    }

    if (array.offset() != 0 || values_length < values->length()) {
      // Must also slice the values
      values = values->Slice(values_offset, values_length);
    }
    RETURN_NOT_OK(VisitArray(*values));
    ++max_recursion_depth_;
    return Status::OK();
  }
  Status Visit(const FixedSizeListArray& array) {
    --max_recursion_depth_;
    auto size = array.list_type()->list_size();
    auto values = array.values()->Slice(array.offset() * size, array.length() * size);

    RETURN_NOT_OK(VisitArray(*values));
    ++max_recursion_depth_;
    return Status::OK();
  }
  Status Visit(const StructArray& array) {
    --max_recursion_depth_;
    for (int i = 0; i < array.num_fields(); ++i) {
      std::shared_ptr<Array> field = array.field(i);
      RETURN_NOT_OK(VisitArray(*field));
    }
    ++max_recursion_depth_;
    return Status::OK();
  }
  Status Visit(const SparseUnionArray& array) {
    const int64_t offset = array.offset();
    const int64_t length = array.length();

    std::shared_ptr<Buffer> type_codes;
    RETURN_NOT_OK(GetTruncatedBuffer(
        offset, length, static_cast<int32_t>(sizeof(UnionArray::type_code_t)),
        array.type_codes(), options_.memory_pool, &type_codes));
    out_->body_buffers.emplace_back(type_codes);

    --max_recursion_depth_;
    for (int i = 0; i < array.num_fields(); ++i) {
      // Sparse union, slicing is done for us by field()
      RETURN_NOT_OK(VisitArray(*array.field(i)));
    }
    ++max_recursion_depth_;
    return Status::OK();
  }
  Status Visit(const DenseUnionArray& array) {
    const int64_t offset = array.offset();
    const int64_t length = array.length();

    std::shared_ptr<Buffer> type_codes;
    RETURN_NOT_OK(GetTruncatedBuffer(
        offset, length, static_cast<int32_t>(sizeof(UnionArray::type_code_t)),
        array.type_codes(), options_.memory_pool, &type_codes));
    out_->body_buffers.emplace_back(type_codes);

    --max_recursion_depth_;
    const auto& type = checked_cast<const UnionType&>(*array.type());

    std::shared_ptr<Buffer> value_offsets;
    RETURN_NOT_OK(
        GetTruncatedBuffer(offset, length, static_cast<int32_t>(sizeof(int32_t)),
                           array.value_offsets(), options_.memory_pool, &value_offsets));

    // The Union type codes are not necessarily 0-indexed
    int8_t max_code = 0;
    for (int8_t code : type.type_codes()) {
      if (code > max_code) {
        max_code = code;
      }
    }

    // Allocate an array of child offsets. Set all to -1 to indicate that we
    // haven't observed a first occurrence of a particular child yet
    std::vector<int32_t> child_offsets(max_code + 1, -1);
    std::vector<int32_t> child_lengths(max_code + 1, 0);

    if (offset != 0) {
      // This is an unpleasant case. Because the offsets are different for
      // each child array, when we have a sliced array, we need to "rebase"
      // the value_offsets for each array

      const int32_t* unshifted_offsets = array.raw_value_offsets();
      const int8_t* type_codes = array.raw_type_codes();

      // Allocate the shifted offsets
      ARROW_ASSIGN_OR_RAISE(
          auto shifted_offsets_buffer,
          AllocateBuffer(length * sizeof(int32_t), options_.memory_pool));
      int32_t* shifted_offsets =
          reinterpret_cast<int32_t*>(shifted_offsets_buffer->mutable_data());

      // Offsets may not be ascending, so we need to find out the start offset
      // for each child
      for (int64_t i = 0; i < length; ++i) {
        const uint8_t code = type_codes[i];
        if (child_offsets[code] == -1) {
          child_offsets[code] = unshifted_offsets[i];
        } else {
          child_offsets[code] = std::min(child_offsets[code], unshifted_offsets[i]);
        }
      }

      // Now compute shifted offsets by subtracting child offset
      for (int64_t i = 0; i < length; ++i) {
        const int8_t code = type_codes[i];
        shifted_offsets[i] = unshifted_offsets[i] - child_offsets[code];
        // Update the child length to account for observed value
        child_lengths[code] = std::max(child_lengths[code], shifted_offsets[i] + 1);
      }

      value_offsets = std::move(shifted_offsets_buffer);
    }
    out_->body_buffers.emplace_back(value_offsets);

    // Visit children and slice accordingly
    for (int i = 0; i < type.num_fields(); ++i) {
      std::shared_ptr<Array> child = array.field(i);

      // TODO: ARROW-809, for sliced unions, tricky to know how much to
      // truncate the children. For now, we are truncating the children to be
      // no longer than the parent union.
      if (offset != 0) {
        const int8_t code = type.type_codes()[i];
        const int64_t child_offset = child_offsets[code];
        const int64_t child_length = child_lengths[code];

        if (child_offset > 0) {
          child = child->Slice(child_offset, child_length);
        } else if (child_length < child->length()) {
          // This case includes when child is not encountered at all
          child = child->Slice(0, child_length);
        }
      }
      RETURN_NOT_OK(VisitArray(*child));
    }
    ++max_recursion_depth_;
    return Status::OK();
  }
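  // Example of the rebasing above: for a sliced union with raw value offsets
  // [17, 4, 18] and type codes [0, 1, 0], the per-child minima are
  // child_offsets[0] = 17 and child_offsets[1] = 4, the shifted offsets
  // become [0, 0, 1], and child_lengths come out as [2, 1], so child 0 is
  // sliced to [17, 19) and child 1 to [4, 5).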
  Status Visit(const DictionaryArray& array) {
    // Dictionary written out separately. Slice offset contained in the indices
    return VisitType(*array.indices());
  }

  Status Visit(const ExtensionArray& array) { return VisitType(*array.storage()); }

  Status VisitType(const Array& values) { return VisitArrayInline(values, this); }
 protected:
  // Destination for output buffers
  IpcPayload* out_;

  std::shared_ptr<KeyValueMetadata> custom_metadata_;

  std::vector<internal::FieldMetadata> field_nodes_;
  std::vector<internal::BufferMetadata> buffer_meta_;

  const IpcWriteOptions& options_;
  int64_t max_recursion_depth_;
  int64_t buffer_start_offset_;
};
class DictionarySerializer : public RecordBatchSerializer {
 public:
  DictionarySerializer(int64_t dictionary_id, bool is_delta, int64_t buffer_start_offset,
                       const IpcWriteOptions& options, IpcPayload* out)
      : RecordBatchSerializer(buffer_start_offset, options, out),
        dictionary_id_(dictionary_id),
        is_delta_(is_delta) {}

  Status SerializeMetadata(int64_t num_rows) override {
    return WriteDictionaryMessage(dictionary_id_, is_delta_, num_rows, out_->body_length,
                                  custom_metadata_, field_nodes_, buffer_meta_, options_,
                                  &out_->metadata);
  }

  Status Assemble(const std::shared_ptr<Array>& dictionary) {
    // Make a dummy record batch. A bit tedious as we have to make a schema
    auto schema = arrow::schema({arrow::field("dictionary", dictionary->type())});
    auto batch = RecordBatch::Make(std::move(schema), dictionary->length(), {dictionary});
    return RecordBatchSerializer::Assemble(*batch);
  }

 private:
  int64_t dictionary_id_;
  bool is_delta_;
};
Status WriteIpcPayload(const IpcPayload& payload, const IpcWriteOptions& options,
                       io::OutputStream* dst, int32_t* metadata_length) {
  RETURN_NOT_OK(WriteMessage(*payload.metadata, options, dst, metadata_length));

#ifndef NDEBUG
  RETURN_NOT_OK(CheckAligned(dst));
#endif

  // Now write the buffers
  for (size_t i = 0; i < payload.body_buffers.size(); ++i) {
    const std::shared_ptr<Buffer>& buffer = payload.body_buffers[i];
    int64_t size = 0;
    int64_t padding = 0;

    // The buffer might be null if we are handling zero row lengths.
    if (buffer) {
      size = buffer->size();
      padding = BitUtil::RoundUpToMultipleOf8(size) - size;
    }

    if (size > 0) {
      RETURN_NOT_OK(dst->Write(buffer));
    }

    if (padding > 0) {
      RETURN_NOT_OK(dst->Write(kPaddingBytes, padding));
    }
  }

#ifndef NDEBUG
  RETURN_NOT_OK(CheckAligned(dst));
#endif

  return Status::OK();
}
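// Wire layout produced by WriteIpcPayload: the length-prefixed metadata
// flatbuffer, then each body buffer padded with zero bytes to the next 8-byte
// boundary:
//
//   <metadata><body buffer 0><padding><body buffer 1><padding>...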
Status GetSchemaPayload(const Schema& schema, const IpcWriteOptions& options,
                        const DictionaryFieldMapper& mapper, IpcPayload* out) {
  out->type = MessageType::SCHEMA;
  return internal::WriteSchemaMessage(schema, mapper, options, &out->metadata);
}
Status GetDictionaryPayload(int64_t id, const std::shared_ptr<Array>& dictionary,
                            const IpcWriteOptions& options, IpcPayload* out) {
  return GetDictionaryPayload(id, false, dictionary, options, out);
}

Status GetDictionaryPayload(int64_t id, bool is_delta,
                            const std::shared_ptr<Array>& dictionary,
                            const IpcWriteOptions& options, IpcPayload* out) {
  out->type = MessageType::DICTIONARY_BATCH;
  // Frame of reference is 0, see ARROW-384
  DictionarySerializer assembler(id, is_delta, /*buffer_start_offset=*/0, options, out);
  return assembler.Assemble(dictionary);
}
Status GetRecordBatchPayload(const RecordBatch& batch, const IpcWriteOptions& options,
                             IpcPayload* out) {
  out->type = MessageType::RECORD_BATCH;
  RecordBatchSerializer assembler(/*buffer_start_offset=*/0, options, out);
  return assembler.Assemble(batch);
}
Status WriteRecordBatch(const RecordBatch& batch, int64_t buffer_start_offset,
                        io::OutputStream* dst, int32_t* metadata_length,
                        int64_t* body_length, const IpcWriteOptions& options) {
  internal::IpcPayload payload;
  RecordBatchSerializer assembler(buffer_start_offset, options, &payload);
  RETURN_NOT_OK(assembler.Assemble(batch));

  // TODO: it's a rough edge that the metadata and body length here are
  // computed separately

  // The body size is computed in the payload
  *body_length = payload.body_length;

  return WriteIpcPayload(payload, options, dst, metadata_length);
}
Status WriteRecordBatchStream(const std::vector<std::shared_ptr<RecordBatch>>& batches,
                              const IpcWriteOptions& options, io::OutputStream* dst) {
  ASSIGN_OR_RAISE(std::shared_ptr<RecordBatchWriter> writer,
                  MakeStreamWriter(dst, batches[0]->schema(), options));
  for (const auto& batch : batches) {
    DCHECK(batch->schema()->Equals(*batches[0]->schema())) << "Schemas unequal";
    RETURN_NOT_OK(writer->WriteRecordBatch(*batch));
  }
  RETURN_NOT_OK(writer->Close());
  return Status::OK();
}
Status WriteTensorHeader(const Tensor& tensor, io::OutputStream* dst,
                         int32_t* metadata_length) {
  IpcWriteOptions options;
  options.alignment = kTensorAlignment;
  std::shared_ptr<Buffer> metadata;
  ARROW_ASSIGN_OR_RAISE(metadata, internal::WriteTensorMessage(tensor, 0, options));
  return WriteMessage(*metadata, options, dst, metadata_length);
}
Status WriteStridedTensorData(int dim_index, int64_t offset, int elem_size,
                              const Tensor& tensor, uint8_t* scratch_space,
                              io::OutputStream* dst) {
  if (dim_index == tensor.ndim() - 1) {
    const uint8_t* data_ptr = tensor.raw_data() + offset;
    const int64_t stride = tensor.strides()[dim_index];
    for (int64_t i = 0; i < tensor.shape()[dim_index]; ++i) {
      memcpy(scratch_space + i * elem_size, data_ptr, elem_size);
      data_ptr += stride;
    }
    return dst->Write(scratch_space, elem_size * tensor.shape()[dim_index]);
  }
  for (int64_t i = 0; i < tensor.shape()[dim_index]; ++i) {
    RETURN_NOT_OK(WriteStridedTensorData(dim_index + 1, offset, elem_size, tensor,
                                         scratch_space, dst));
    offset += tensor.strides()[dim_index];
  }
  return Status::OK();
}
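// Example: a non-contiguous 2x3 tensor is written by gathering each logical
// row into scratch_space one element at a time (following the strides) and
// writing the rows in order, yielding a contiguous row-major body.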
Status GetContiguousTensor(const Tensor& tensor, MemoryPool* pool,
                           std::unique_ptr<Tensor>* out) {
  const int elem_size = GetByteWidth(*tensor.type());

  ARROW_ASSIGN_OR_RAISE(
      auto scratch_space,
      AllocateBuffer(tensor.shape()[tensor.ndim() - 1] * elem_size, pool));

  ARROW_ASSIGN_OR_RAISE(std::shared_ptr<ResizableBuffer> contiguous_data,
                        AllocateResizableBuffer(tensor.size() * elem_size, pool));

  io::BufferOutputStream stream(contiguous_data);
  RETURN_NOT_OK(WriteStridedTensorData(0, 0, elem_size, tensor,
                                       scratch_space->mutable_data(), &stream));

  out->reset(new Tensor(tensor.type(), contiguous_data, tensor.shape()));

  return Status::OK();
}
Status WriteTensor(const Tensor& tensor, io::OutputStream* dst, int32_t* metadata_length,
                   int64_t* body_length) {
  const int elem_size = GetByteWidth(*tensor.type());

  *body_length = tensor.size() * elem_size;

  // Tensor metadata accounts for padding
  if (tensor.is_contiguous()) {
    RETURN_NOT_OK(WriteTensorHeader(tensor, dst, metadata_length));
    auto data = tensor.data();
    if (data && data->data()) {
      RETURN_NOT_OK(dst->Write(data->data(), *body_length));
    }
  } else {
    // The tensor written is made contiguous
    Tensor dummy(tensor.type(), nullptr, tensor.shape());
    RETURN_NOT_OK(WriteTensorHeader(dummy, dst, metadata_length));

    // TODO: Do we care enough about this temporary allocation to pass in a
    // MemoryPool to this function?
    ARROW_ASSIGN_OR_RAISE(auto scratch_space,
                          AllocateBuffer(tensor.shape()[tensor.ndim() - 1] * elem_size));

    RETURN_NOT_OK(WriteStridedTensorData(0, 0, elem_size, tensor,
                                         scratch_space->mutable_data(), dst));
  }

  return Status::OK();
}
Result<std::unique_ptr<Message>> GetTensorMessage(const Tensor& tensor,
                                                  MemoryPool* pool) {
  const Tensor* tensor_to_write = &tensor;
  std::unique_ptr<Tensor> temp_tensor;

  if (!tensor.is_contiguous()) {
    RETURN_NOT_OK(GetContiguousTensor(tensor, pool, &temp_tensor));
    tensor_to_write = temp_tensor.get();
  }

  IpcWriteOptions options;
  options.alignment = kTensorAlignment;
  std::shared_ptr<Buffer> metadata;
  ARROW_ASSIGN_OR_RAISE(metadata,
                        internal::WriteTensorMessage(*tensor_to_write, 0, options));
  return std::unique_ptr<Message>(new Message(metadata, tensor_to_write->data()));
}
namespace internal {

class SparseTensorSerializer {
 public:
  SparseTensorSerializer(int64_t buffer_start_offset, IpcPayload* out)
      : out_(out),
        buffer_start_offset_(buffer_start_offset),
        options_(IpcWriteOptions::Defaults()) {}

  ~SparseTensorSerializer() = default;
  Status VisitSparseIndex(const SparseIndex& sparse_index) {
    switch (sparse_index.format_id()) {
      case SparseTensorFormat::COO:
        RETURN_NOT_OK(
            VisitSparseCOOIndex(checked_cast<const SparseCOOIndex&>(sparse_index)));
        break;

      case SparseTensorFormat::CSR:
        RETURN_NOT_OK(
            VisitSparseCSRIndex(checked_cast<const SparseCSRIndex&>(sparse_index)));
        break;

      case SparseTensorFormat::CSC:
        RETURN_NOT_OK(
            VisitSparseCSCIndex(checked_cast<const SparseCSCIndex&>(sparse_index)));
        break;

      case SparseTensorFormat::CSF:
        RETURN_NOT_OK(
            VisitSparseCSFIndex(checked_cast<const SparseCSFIndex&>(sparse_index)));
        break;

      default:
        std::stringstream ss;
        ss << "Unable to convert type: " << sparse_index.ToString() << std::endl;
        return Status::NotImplemented(ss.str());
    }
    return Status::OK();
  }
  Status SerializeMetadata(const SparseTensor& sparse_tensor) {
    return WriteSparseTensorMessage(sparse_tensor, out_->body_length, buffer_meta_,
                                    options_)
        .Value(&out_->metadata);
  }
  Status Assemble(const SparseTensor& sparse_tensor) {
    if (buffer_meta_.size() > 0) {
      buffer_meta_.clear();
      out_->body_buffers.clear();
    }

    RETURN_NOT_OK(VisitSparseIndex(*sparse_tensor.sparse_index()));
    out_->body_buffers.emplace_back(sparse_tensor.data());

    int64_t offset = buffer_start_offset_;
    buffer_meta_.reserve(out_->body_buffers.size());

    for (size_t i = 0; i < out_->body_buffers.size(); ++i) {
      const Buffer* buffer = out_->body_buffers[i].get();
      int64_t size = buffer->size();
      int64_t padding = BitUtil::RoundUpToMultipleOf8(size) - size;
      buffer_meta_.push_back({offset, size + padding});
      offset += size + padding;
    }

    out_->body_length = offset - buffer_start_offset_;
    DCHECK(BitUtil::IsMultipleOf8(out_->body_length));

    return SerializeMetadata(sparse_tensor);
  }
 private:
  Status VisitSparseCOOIndex(const SparseCOOIndex& sparse_index) {
    out_->body_buffers.emplace_back(sparse_index.indices()->data());
    return Status::OK();
  }

  Status VisitSparseCSRIndex(const SparseCSRIndex& sparse_index) {
    out_->body_buffers.emplace_back(sparse_index.indptr()->data());
    out_->body_buffers.emplace_back(sparse_index.indices()->data());
    return Status::OK();
  }

  Status VisitSparseCSCIndex(const SparseCSCIndex& sparse_index) {
    out_->body_buffers.emplace_back(sparse_index.indptr()->data());
    out_->body_buffers.emplace_back(sparse_index.indices()->data());
    return Status::OK();
  }

  Status VisitSparseCSFIndex(const SparseCSFIndex& sparse_index) {
    for (const std::shared_ptr<arrow::Tensor>& indptr : sparse_index.indptr()) {
      out_->body_buffers.emplace_back(indptr->data());
    }
    for (const std::shared_ptr<arrow::Tensor>& indices : sparse_index.indices()) {
      out_->body_buffers.emplace_back(indices->data());
    }
    return Status::OK();
  }

  IpcPayload* out_;

  std::vector<internal::BufferMetadata> buffer_meta_;
  int64_t buffer_start_offset_;
  IpcWriteOptions options_;
};
}  // namespace internal
Status WriteSparseTensor(const SparseTensor& sparse_tensor, io::OutputStream* dst,
                         int32_t* metadata_length, int64_t* body_length) {
  internal::IpcPayload payload;
  internal::SparseTensorSerializer writer(0, &payload);
  RETURN_NOT_OK(writer.Assemble(sparse_tensor));

  *body_length = payload.body_length;
  return WriteIpcPayload(payload, IpcWriteOptions::Defaults(), dst, metadata_length);
}
Status GetSparseTensorPayload(const SparseTensor& sparse_tensor, MemoryPool* pool,
                              internal::IpcPayload* out) {
  internal::SparseTensorSerializer writer(0, out);
  return writer.Assemble(sparse_tensor);
}
Result<std::unique_ptr<Message>> GetSparseTensorMessage(const SparseTensor& sparse_tensor,
                                                        MemoryPool* pool) {
  internal::IpcPayload payload;
  RETURN_NOT_OK(GetSparseTensorPayload(sparse_tensor, pool, &payload));
  return std::unique_ptr<Message>(
      new Message(std::move(payload.metadata), std::move(payload.body_buffers[0])));
}
int64_t GetPayloadSize(const IpcPayload& payload, const IpcWriteOptions& options) {
  const int32_t prefix_size = options.write_legacy_ipc_format ? 4 : 8;
  const int32_t flatbuffer_size = static_cast<int32_t>(payload.metadata->size());
  const int32_t padded_message_length = static_cast<int32_t>(
      PaddedLength(flatbuffer_size + prefix_size, options.alignment));
  // body_length already accounts for padding
  return payload.body_length + padded_message_length;
}
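// Example: with the modern prefix (4-byte continuation marker + 4-byte
// length, prefix_size = 8), a 100-byte metadata flatbuffer occupies
// PaddedLength(108, 8) = 112 bytes, so GetPayloadSize returns
// 112 + payload.body_length.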
Status GetRecordBatchSize(const RecordBatch& batch, int64_t* size) {
  return GetRecordBatchSize(batch, IpcWriteOptions::Defaults(), size);
}

Status GetRecordBatchSize(const RecordBatch& batch, const IpcWriteOptions& options,
                          int64_t* size) {
  // emulates the behavior of Write without actually writing
  int32_t metadata_length = 0;
  int64_t body_length = 0;
  io::MockOutputStream dst;
  RETURN_NOT_OK(
      WriteRecordBatch(batch, 0, &dst, &metadata_length, &body_length, options));
  *size = dst.GetExtentBytesWritten();
  return Status::OK();
}
Status GetTensorSize(const Tensor& tensor, int64_t* size) {
  // emulates the behavior of Write without actually writing
  int32_t metadata_length = 0;
  int64_t body_length = 0;
  io::MockOutputStream dst;
  RETURN_NOT_OK(WriteTensor(tensor, &dst, &metadata_length, &body_length));
  *size = dst.GetExtentBytesWritten();
  return Status::OK();
}
// ----------------------------------------------------------------------

RecordBatchWriter::~RecordBatchWriter() {}

Status RecordBatchWriter::WriteTable(const Table& table, int64_t max_chunksize) {
  TableBatchReader reader(table);

  if (max_chunksize > 0) {
    reader.set_chunksize(max_chunksize);
  }

  std::shared_ptr<RecordBatch> batch;
  while (true) {
    RETURN_NOT_OK(reader.ReadNext(&batch));
    if (batch == nullptr) {
      break;
    }
    RETURN_NOT_OK(WriteRecordBatch(*batch));
  }

  return Status::OK();
}

Status RecordBatchWriter::WriteTable(const Table& table) { return WriteTable(table, -1); }
// ----------------------------------------------------------------------
// Payload writer implementation

namespace internal {

IpcPayloadWriter::~IpcPayloadWriter() {}

Status IpcPayloadWriter::Start() { return Status::OK(); }

class ARROW_EXPORT IpcFormatWriter : public RecordBatchWriter {
 public:
  // A RecordBatchWriter implementation that writes to an IpcPayloadWriter.
  IpcFormatWriter(std::unique_ptr<internal::IpcPayloadWriter> payload_writer,
                  const Schema& schema, const IpcWriteOptions& options,
                  bool is_file_format)
      : payload_writer_(std::move(payload_writer)),
        schema_(schema),
        mapper_(schema),
        is_file_format_(is_file_format),
        options_(options) {}

  // A Schema-owning constructor variant
  IpcFormatWriter(std::unique_ptr<internal::IpcPayloadWriter> payload_writer,
                  const std::shared_ptr<Schema>& schema, const IpcWriteOptions& options,
                  bool is_file_format)
      : IpcFormatWriter(std::move(payload_writer), *schema, options, is_file_format) {
    shared_schema_ = schema;
  }
  Status WriteRecordBatch(const RecordBatch& batch) override {
    if (!batch.schema()->Equals(schema_, false /* check_metadata */)) {
      return Status::Invalid("Tried to write record batch with different schema");
    }

    RETURN_NOT_OK(CheckStarted());

    RETURN_NOT_OK(WriteDictionaries(batch));

    IpcPayload payload;
    RETURN_NOT_OK(GetRecordBatchPayload(batch, options_, &payload));
    RETURN_NOT_OK(WritePayload(payload));
    ++stats_.num_record_batches;
    return Status::OK();
  }
  Status WriteTable(const Table& table, int64_t max_chunksize) override {
    if (is_file_format_ && options_.unify_dictionaries) {
      ARROW_ASSIGN_OR_RAISE(auto unified_table,
                            DictionaryUnifier::UnifyTable(table, options_.memory_pool));
      return RecordBatchWriter::WriteTable(*unified_table, max_chunksize);
    }
    return RecordBatchWriter::WriteTable(table, max_chunksize);
  }
  Status Close() override {
    RETURN_NOT_OK(CheckStarted());
    return payload_writer_->Close();
  }
  Status Start() {
    started_ = true;
    RETURN_NOT_OK(payload_writer_->Start());

    IpcPayload payload;
    RETURN_NOT_OK(GetSchemaPayload(schema_, options_, mapper_, &payload));
    return WritePayload(payload);
  }
  WriteStats stats() const override { return stats_; }
  Status CheckStarted() {
    if (!started_) {
      return Start();
    }
    return Status::OK();
  }
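  // Decision summary for a dictionary id that was already emitted:
  //   - same data pointer or equal contents -> emit nothing
  //   - file format                         -> error (replacement unsupported)
  //   - new dictionary extends the last one, deltas enabled, no nested
  //     dictionary                          -> emit only the new tail as a delta
  //   - otherwise                           -> emit a full replacement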
  Status WriteDictionaries(const RecordBatch& batch) {
    ARROW_ASSIGN_OR_RAISE(const auto dictionaries, CollectDictionaries(batch, mapper_));
    const auto equal_options = EqualOptions().nans_equal(true);

    for (const auto& pair : dictionaries) {
      int64_t dictionary_id = pair.first;
      const auto& dictionary = pair.second;

      // If a dictionary with this id was already emitted, check if it was the same.
      auto* last_dictionary = &last_dictionaries_[dictionary_id];
      const bool dictionary_exists = (*last_dictionary != nullptr);
      int64_t delta_start = 0;
      if (dictionary_exists) {
        if ((*last_dictionary)->data() == dictionary->data()) {
          // Fast shortcut for a common case.
          // Same dictionary data by pointer => no need to emit it again
          continue;
        }
        const int64_t last_length = (*last_dictionary)->length();
        const int64_t new_length = dictionary->length();
        if (new_length == last_length &&
            ((*last_dictionary)->Equals(dictionary, equal_options))) {
          // Same dictionary by value => no need to emit it again
          // (while this can have a CPU cost, this code path is required
          // for the IPC file format)
          continue;
        }
        if (is_file_format_) {
          return Status::Invalid(
              "Dictionary replacement detected when writing IPC file format. "
              "Arrow IPC files only support a single dictionary for a given field "
              "across all batches.");
        }

        // (the read path doesn't support outer dictionary deltas, don't emit them)
        if (new_length > last_length && options_.emit_dictionary_deltas &&
            !HasNestedDict(*dictionary->data()) &&
            ((*last_dictionary)
                 ->RangeEquals(dictionary, 0, last_length, 0, equal_options))) {
          // New dictionary starts with the current dictionary
          delta_start = last_length;
        }
      }

      IpcPayload payload;
      if (delta_start > 0) {
        RETURN_NOT_OK(GetDictionaryPayload(dictionary_id, /*is_delta=*/true,
                                           dictionary->Slice(delta_start), options_,
                                           &payload));
      } else {
        RETURN_NOT_OK(
            GetDictionaryPayload(dictionary_id, dictionary, options_, &payload));
      }
      RETURN_NOT_OK(WritePayload(payload));
      ++stats_.num_dictionary_batches;
      if (dictionary_exists) {
        if (delta_start > 0) {
          ++stats_.num_dictionary_deltas;
        } else {
          ++stats_.num_replaced_dictionaries;
        }
      }

      // Remember dictionary for next batches
      *last_dictionary = dictionary;
    }
    return Status::OK();
  }
WritePayload(const IpcPayload
& payload
) {
1109 RETURN_NOT_OK(payload_writer_
->WritePayload(payload
));
1110 ++stats_
.num_messages
;
1111 return Status::OK();
1114 std::unique_ptr
<IpcPayloadWriter
> payload_writer_
;
1115 std::shared_ptr
<Schema
> shared_schema_
;
1116 const Schema
& schema_
;
1117 const DictionaryFieldMapper mapper_
;
1118 const bool is_file_format_
;
1120 // A map of last-written dictionaries by id.
1121 // This is required to avoid the same dictionary again and again,
1122 // and also for correctness when writing the IPC file format
1123 // (where replacements and deltas are unsupported).
1124 // The latter is also why we can't use weak_ptr.
1125 std::unordered_map
<int64_t, std::shared_ptr
<Array
>> last_dictionaries_
;
1127 bool started_
= false;
1128 IpcWriteOptions options_
;
class StreamBookKeeper {
 public:
  StreamBookKeeper(const IpcWriteOptions& options, io::OutputStream* sink)
      : options_(options), sink_(sink), position_(-1) {}
  StreamBookKeeper(const IpcWriteOptions& options, std::shared_ptr<io::OutputStream> sink)
      : options_(options),
        sink_(sink.get()),
        owned_sink_(std::move(sink)),
        position_(-1) {}

  Status UpdatePosition() { return sink_->Tell().Value(&position_); }

  Status UpdatePositionCheckAligned() {
    RETURN_NOT_OK(UpdatePosition());
    DCHECK_EQ(0, position_ % 8) << "Stream is not aligned";
    return Status::OK();
  }
Align(int32_t alignment
= kArrowIpcAlignment
) {
1151 // Adds padding bytes if necessary to ensure all memory blocks are written on
1152 // 8-byte (or other alignment) boundaries.
1153 int64_t remainder
= PaddedLength(position_
, alignment
) - position_
;
1154 if (remainder
> 0) {
1155 return Write(kPaddingBytes
, remainder
);
1157 return Status::OK();
  // Write data and update position
  Status Write(const void* data, int64_t nbytes) {
    RETURN_NOT_OK(sink_->Write(data, nbytes));
    position_ += nbytes;
    return Status::OK();
  }
  Status WriteEOS() {
    // End of stream marker
    constexpr int32_t kZeroLength = 0;
    if (!options_.write_legacy_ipc_format) {
      RETURN_NOT_OK(Write(&kIpcContinuationToken, sizeof(int32_t)));
    }
    return Write(&kZeroLength, sizeof(int32_t));
  }
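  // The resulting EOS marker is 0xFFFFFFFF 0x00000000 in the modern format
  // (continuation token followed by a zero length), or a single 0x00000000
  // in the pre-0.15.0 legacy format.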
 protected:
  IpcWriteOptions options_;
  io::OutputStream* sink_;
  std::shared_ptr<io::OutputStream> owned_sink_;
  int64_t position_;
};
/// An IpcPayloadWriter implementation that writes to an IPC stream
/// (with an end-of-stream marker)
class PayloadStreamWriter : public IpcPayloadWriter, protected StreamBookKeeper {
 public:
  PayloadStreamWriter(io::OutputStream* sink,
                      const IpcWriteOptions& options = IpcWriteOptions::Defaults())
      : StreamBookKeeper(options, sink) {}
  PayloadStreamWriter(std::shared_ptr<io::OutputStream> sink,
                      const IpcWriteOptions& options = IpcWriteOptions::Defaults())
      : StreamBookKeeper(options, std::move(sink)) {}

  ~PayloadStreamWriter() override = default;

  Status WritePayload(const IpcPayload& payload) override {
#ifndef NDEBUG
    // Catch bug fixed in ARROW-3236
    RETURN_NOT_OK(UpdatePositionCheckAligned());
#endif

    int32_t metadata_length = 0;  // unused
    RETURN_NOT_OK(WriteIpcPayload(payload, options_, sink_, &metadata_length));
    RETURN_NOT_OK(UpdatePositionCheckAligned());
    return Status::OK();
  }

  Status Close() override { return WriteEOS(); }
};
/// An IpcPayloadWriter implementation that writes to an IPC file
/// (with a footer as defined in File.fbs)
class PayloadFileWriter : public internal::IpcPayloadWriter, protected StreamBookKeeper {
 public:
  PayloadFileWriter(const IpcWriteOptions& options, const std::shared_ptr<Schema>& schema,
                    const std::shared_ptr<const KeyValueMetadata>& metadata,
                    io::OutputStream* sink)
      : StreamBookKeeper(options, sink), schema_(schema), metadata_(metadata) {}
  PayloadFileWriter(const IpcWriteOptions& options, const std::shared_ptr<Schema>& schema,
                    const std::shared_ptr<const KeyValueMetadata>& metadata,
                    std::shared_ptr<io::OutputStream> sink)
      : StreamBookKeeper(options, std::move(sink)),
        schema_(schema),
        metadata_(metadata) {}

  ~PayloadFileWriter() override = default;

  Status WritePayload(const IpcPayload& payload) override {
#ifndef NDEBUG
    // Catch bug fixed in ARROW-3236
    RETURN_NOT_OK(UpdatePositionCheckAligned());
#endif

    // Metadata length must include padding, it's computed by WriteIpcPayload()
    FileBlock block = {position_, 0, payload.body_length};
    RETURN_NOT_OK(WriteIpcPayload(payload, options_, sink_, &block.metadata_length));
    RETURN_NOT_OK(UpdatePositionCheckAligned());

    // Record position and size of some message types, to list them in the footer
    switch (payload.type) {
      case MessageType::DICTIONARY_BATCH:
        dictionaries_.push_back(block);
        break;
      case MessageType::RECORD_BATCH:
        record_batches_.push_back(block);
        break;
      default:
        break;
    }

    return Status::OK();
  }
  Status Start() override {
    // ARROW-3236: The initial position -1 needs to be updated to the stream's
    // current position otherwise an incorrect amount of padding will be
    // written to new files.
    RETURN_NOT_OK(UpdatePosition());

    // It is only necessary to align to 8-byte boundary at the start of the file
    RETURN_NOT_OK(Write(kArrowMagicBytes, strlen(kArrowMagicBytes)));
    RETURN_NOT_OK(Align());

    return Status::OK();
  }
  Status Close() override {
    // Write 0 EOS message for compatibility with sequential readers
    RETURN_NOT_OK(WriteEOS());

    // Write file footer
    RETURN_NOT_OK(UpdatePosition());
    int64_t initial_position = position_;
    RETURN_NOT_OK(
        WriteFileFooter(*schema_, dictionaries_, record_batches_, metadata_, sink_));

    // Write footer length
    RETURN_NOT_OK(UpdatePosition());
    int32_t footer_length = static_cast<int32_t>(position_ - initial_position);
    if (footer_length <= 0) {
      return Status::Invalid("Invalid file footer");
    }

    // write footer length in little endian
    footer_length = BitUtil::ToLittleEndian(footer_length);
    RETURN_NOT_OK(Write(&footer_length, sizeof(int32_t)));

    // Write magic bytes to end file
    return Write(kArrowMagicBytes, strlen(kArrowMagicBytes));
  }
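  // Overall file layout produced by Start()/WritePayload()/Close():
  //
  //   ARROW1 <padding> <stream messages...> <EOS marker>
  //   <footer flatbuffer> <int32 footer length> ARROW1
  //
  // Readers locate the footer from the fixed-size suffix at the end.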
 protected:
  std::shared_ptr<Schema> schema_;
  std::shared_ptr<const KeyValueMetadata> metadata_;
  std::vector<FileBlock> dictionaries_;
  std::vector<FileBlock> record_batches_;
};
}  // namespace internal
Result<std::shared_ptr<RecordBatchWriter>> MakeStreamWriter(
    io::OutputStream* sink, const std::shared_ptr<Schema>& schema,
    const IpcWriteOptions& options) {
  return std::make_shared<internal::IpcFormatWriter>(
      ::arrow::internal::make_unique<internal::PayloadStreamWriter>(sink, options),
      schema, options, /*is_file_format=*/false);
}
Result<std::shared_ptr<RecordBatchWriter>> MakeStreamWriter(
    std::shared_ptr<io::OutputStream> sink, const std::shared_ptr<Schema>& schema,
    const IpcWriteOptions& options) {
  return std::make_shared<internal::IpcFormatWriter>(
      ::arrow::internal::make_unique<internal::PayloadStreamWriter>(std::move(sink),
                                                                    options),
      schema, options, /*is_file_format=*/false);
}
Result<std::shared_ptr<RecordBatchWriter>> NewStreamWriter(
    io::OutputStream* sink, const std::shared_ptr<Schema>& schema,
    const IpcWriteOptions& options) {
  return MakeStreamWriter(sink, schema, options);
}
Result<std::shared_ptr<RecordBatchWriter>> MakeFileWriter(
    io::OutputStream* sink, const std::shared_ptr<Schema>& schema,
    const IpcWriteOptions& options,
    const std::shared_ptr<const KeyValueMetadata>& metadata) {
  return std::make_shared<internal::IpcFormatWriter>(
      ::arrow::internal::make_unique<internal::PayloadFileWriter>(options, schema,
                                                                  metadata, sink),
      schema, options, /*is_file_format=*/true);
}
Result<std::shared_ptr<RecordBatchWriter>> MakeFileWriter(
    std::shared_ptr<io::OutputStream> sink, const std::shared_ptr<Schema>& schema,
    const IpcWriteOptions& options,
    const std::shared_ptr<const KeyValueMetadata>& metadata) {
  return std::make_shared<internal::IpcFormatWriter>(
      ::arrow::internal::make_unique<internal::PayloadFileWriter>(
          options, schema, metadata, std::move(sink)),
      schema, options, /*is_file_format=*/true);
}
Result<std::shared_ptr<RecordBatchWriter>> NewFileWriter(
    io::OutputStream* sink, const std::shared_ptr<Schema>& schema,
    const IpcWriteOptions& options,
    const std::shared_ptr<const KeyValueMetadata>& metadata) {
  return MakeFileWriter(sink, schema, options, metadata);
}
namespace internal {

Result<std::unique_ptr<RecordBatchWriter>> OpenRecordBatchWriter(
    std::unique_ptr<IpcPayloadWriter> sink, const std::shared_ptr<Schema>& schema,
    const IpcWriteOptions& options) {
  // XXX should we call Start()?
  return ::arrow::internal::make_unique<internal::IpcFormatWriter>(
      std::move(sink), schema, options, /*is_file_format=*/false);
}
Result<std::unique_ptr<IpcPayloadWriter>> MakePayloadStreamWriter(
    io::OutputStream* sink, const IpcWriteOptions& options) {
  return ::arrow::internal::make_unique<internal::PayloadStreamWriter>(sink, options);
}
Result<std::unique_ptr<IpcPayloadWriter>> MakePayloadFileWriter(
    io::OutputStream* sink, const std::shared_ptr<Schema>& schema,
    const IpcWriteOptions& options,
    const std::shared_ptr<const KeyValueMetadata>& metadata) {
  return ::arrow::internal::make_unique<internal::PayloadFileWriter>(options, schema,
                                                                     metadata, sink);
}

}  // namespace internal
// ----------------------------------------------------------------------
// Serialization public APIs

Result<std::shared_ptr<Buffer>> SerializeRecordBatch(const RecordBatch& batch,
                                                     std::shared_ptr<MemoryManager> mm) {
  auto options = IpcWriteOptions::Defaults();
  int64_t size = 0;
  RETURN_NOT_OK(GetRecordBatchSize(batch, options, &size));
  ARROW_ASSIGN_OR_RAISE(auto buffer, mm->AllocateBuffer(size));
  ARROW_ASSIGN_OR_RAISE(auto writer, Buffer::GetWriter(buffer));

  // XXX Should we have a helper function for getting a MemoryPool
  // for any MemoryManager (not only CPU)?
  if (mm->is_cpu()) {
    options.memory_pool = checked_pointer_cast<CPUMemoryManager>(mm)->pool();
  }

  RETURN_NOT_OK(SerializeRecordBatch(batch, options, writer.get()));
  RETURN_NOT_OK(writer->Close());
  return buffer;
}
Result<std::shared_ptr<Buffer>> SerializeRecordBatch(const RecordBatch& batch,
                                                     const IpcWriteOptions& options) {
  int64_t size = 0;
  RETURN_NOT_OK(GetRecordBatchSize(batch, options, &size));
  ARROW_ASSIGN_OR_RAISE(std::shared_ptr<Buffer> buffer,
                        AllocateBuffer(size, options.memory_pool));

  io::FixedSizeBufferWriter stream(buffer);
  RETURN_NOT_OK(SerializeRecordBatch(batch, options, &stream));
  return buffer;
}
Status SerializeRecordBatch(const RecordBatch& batch, const IpcWriteOptions& options,
                            io::OutputStream* out) {
  int32_t metadata_length = 0;
  int64_t body_length = 0;
  return WriteRecordBatch(batch, 0, out, &metadata_length, &body_length, options);
}
Result<std::shared_ptr<Buffer>> SerializeSchema(const Schema& schema, MemoryPool* pool) {
  ARROW_ASSIGN_OR_RAISE(auto stream, io::BufferOutputStream::Create(1024, pool));

  auto options = IpcWriteOptions::Defaults();
  const bool is_file_format = false;  // indifferent as we don't write dictionaries
  internal::IpcFormatWriter writer(
      ::arrow::internal::make_unique<internal::PayloadStreamWriter>(stream.get()), schema,
      options, is_file_format);
  RETURN_NOT_OK(writer.Start());
  return stream->Finish();
}
}  // namespace ipc
}  // namespace arrow