]>
Commit | Line | Data |
---|---|---|
1d09f67e TL |
1 | // Licensed to the Apache Software Foundation (ASF) under one or more |
2 | // contributor license agreements. See the NOTICE file distributed with | |
3 | // this work for additional information regarding copyright ownership. | |
4 | // The ASF licenses this file to You under the Apache License, Version 2.0 | |
5 | // (the "License"); you may not use this file except in compliance with | |
6 | // the License. You may obtain a copy of the License at | |
7 | // | |
8 | // http://www.apache.org/licenses/LICENSE-2.0 | |
9 | // | |
10 | // Unless required by applicable law or agreed to in writing, software | |
11 | // distributed under the License is distributed on an "AS IS" BASIS, | |
12 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
13 | // See the License for the specific language governing permissions and | |
14 | // limitations under the License. | |
15 | ||
16 | using System; | |
17 | using System.Buffers; | |
18 | using System.Buffers.Binary; | |
19 | using System.Collections.Generic; | |
20 | using System.Diagnostics; | |
21 | using System.IO; | |
22 | using System.Threading; | |
23 | using System.Threading.Tasks; | |
24 | using Apache.Arrow.Arrays; | |
25 | using Apache.Arrow.Types; | |
26 | using FlatBuffers; | |
27 | ||
28 | namespace Apache.Arrow.Ipc | |
29 | { | |
30 | public class ArrowStreamWriter : IDisposable | |
31 | { | |
/// <summary>
/// Array visitor that records, in visit order, every <see cref="ArrowBuffer"/> of the
/// visited arrays together with the offset assigned to it.
/// </summary>
/// <remarks>
/// Each recorded buffer receives the current <see cref="TotalLength"/> as its offset, after
/// which <see cref="TotalLength"/> grows by the buffer length rounded up to a multiple of 8.
/// </remarks>
internal class ArrowRecordBatchFlatBufferBuilder :
    IArrowArrayVisitor<Int8Array>,
    IArrowArrayVisitor<Int16Array>,
    IArrowArrayVisitor<Int32Array>,
    IArrowArrayVisitor<Int64Array>,
    IArrowArrayVisitor<UInt8Array>,
    IArrowArrayVisitor<UInt16Array>,
    IArrowArrayVisitor<UInt32Array>,
    IArrowArrayVisitor<UInt64Array>,
    IArrowArrayVisitor<FloatArray>,
    IArrowArrayVisitor<DoubleArray>,
    IArrowArrayVisitor<BooleanArray>,
    IArrowArrayVisitor<TimestampArray>,
    IArrowArrayVisitor<Date32Array>,
    IArrowArrayVisitor<Date64Array>,
    IArrowArrayVisitor<ListArray>,
    IArrowArrayVisitor<StringArray>,
    IArrowArrayVisitor<BinaryArray>,
    IArrowArrayVisitor<FixedSizeBinaryArray>,
    IArrowArrayVisitor<StructArray>,
    IArrowArrayVisitor<Decimal128Array>,
    IArrowArrayVisitor<Decimal256Array>,
    IArrowArrayVisitor<DictionaryArray>
{
    /// <summary>
    /// A buffer paired with the offset that was assigned to it when it was recorded.
    /// </summary>
    public readonly struct Buffer
    {
        public readonly ArrowBuffer DataBuffer;
        public readonly int Offset;

        public Buffer(ArrowBuffer buffer, int offset)
        {
            DataBuffer = buffer;
            Offset = offset;
        }
    }

    private readonly List<Buffer> _buffers;

    /// <summary>The buffers recorded so far, in the order they were visited.</summary>
    public IReadOnlyList<Buffer> Buffers => _buffers;

    /// <summary>Sum of the 8-byte-padded lengths of all recorded buffers.</summary>
    public int TotalLength { get; private set; }

    public ArrowRecordBatchFlatBufferBuilder()
    {
        _buffers = new List<Buffer>();
        TotalLength = 0;
    }

    // Primitive-style arrays: null bitmap followed by the value buffer.
    public void Visit(Int8Array array) => CreateBuffers(array);
    public void Visit(Int16Array array) => CreateBuffers(array);
    public void Visit(Int32Array array) => CreateBuffers(array);
    public void Visit(Int64Array array) => CreateBuffers(array);
    public void Visit(UInt8Array array) => CreateBuffers(array);
    public void Visit(UInt16Array array) => CreateBuffers(array);
    public void Visit(UInt32Array array) => CreateBuffers(array);
    public void Visit(UInt64Array array) => CreateBuffers(array);
    public void Visit(FloatArray array) => CreateBuffers(array);
    public void Visit(DoubleArray array) => CreateBuffers(array);
    public void Visit(TimestampArray array) => CreateBuffers(array);
    public void Visit(BooleanArray array) => CreateBuffers(array);
    public void Visit(Date32Array array) => CreateBuffers(array);
    public void Visit(Date64Array array) => CreateBuffers(array);

    /// <summary>Records the null bitmap and offsets of the list, then visits its values.</summary>
    public void Visit(ListArray array)
    {
        AddBuffer(array.NullBitmapBuffer);
        AddBuffer(array.ValueOffsetsBuffer);

        array.Values.Accept(this);
    }

    /// <summary>Strings are recorded exactly like binary arrays.</summary>
    public void Visit(StringArray array) => Visit((BinaryArray)array);

    /// <summary>Records the null bitmap, the offsets buffer and the value buffer.</summary>
    public void Visit(BinaryArray array)
    {
        AddBuffer(array.NullBitmapBuffer);
        AddBuffer(array.ValueOffsetsBuffer);
        AddBuffer(array.ValueBuffer);
    }

    public void Visit(FixedSizeBinaryArray array)
    {
        AddBuffer(array.NullBitmapBuffer);
        AddBuffer(array.ValueBuffer);
    }

    public void Visit(Decimal128Array array)
    {
        AddBuffer(array.NullBitmapBuffer);
        AddBuffer(array.ValueBuffer);
    }

    public void Visit(Decimal256Array array)
    {
        AddBuffer(array.NullBitmapBuffer);
        AddBuffer(array.ValueBuffer);
    }

    /// <summary>Records the struct's null bitmap, then visits every child field in order.</summary>
    public void Visit(StructArray array)
    {
        AddBuffer(array.NullBitmapBuffer);

        for (int i = 0; i < array.Fields.Count; i++)
        {
            array.Fields[i].Accept(this);
        }
    }

    public void Visit(DictionaryArray array)
    {
        // Dictionary is serialized separately in Dictionary serialization.
        // We are only interested in indices at this context.

        AddBuffer(array.NullBitmapBuffer);
        AddBuffer(array.IndicesBuffer);
    }

    private void CreateBuffers(BooleanArray array)
    {
        AddBuffer(array.NullBitmapBuffer);
        AddBuffer(array.ValueBuffer);
    }

    private void CreateBuffers<T>(PrimitiveArray<T> array)
        where T : struct
    {
        AddBuffer(array.NullBitmapBuffer);
        AddBuffer(array.ValueBuffer);
    }

    // Appends the buffer to the recorded list at the next free offset.
    private void AddBuffer(ArrowBuffer buffer) => _buffers.Add(CreateBuffer(buffer));

    /// <summary>
    /// Assigns the next offset to <paramref name="buffer"/> and advances <see cref="TotalLength"/>
    /// by the buffer's padded length.
    /// </summary>
    /// <exception cref="OverflowException">
    /// The padded length, or the accumulated total length, does not fit in an <see cref="int"/>.
    /// </exception>
    private Buffer CreateBuffer(ArrowBuffer buffer)
    {
        int offset = TotalLength;

        int paddedLength = checked((int)BitUtility.RoundUpToMultipleOf8(buffer.Length));

        // The running total must be overflow-checked as well: a silently wrapped total
        // would yield negative/incorrect offsets for every subsequent buffer.
        TotalLength = checked(TotalLength + paddedLength);

        return new Buffer(buffer, offset);
    }

    /// <summary>Fallback for array types that have no dedicated overload; always throws.</summary>
    public void Visit(IArrowArray array)
    {
        throw new NotImplementedException();
    }
}
177 | ||
/// <summary>The stream that every message and buffer is written to.</summary>
protected Stream BaseStream { get; }

/// <summary>Array pool created in the constructor; used to rent scratch buffers.</summary>
protected ArrayPool<byte> Buffers { get; }

/// <summary>The FlatBuffer builder used to build message metadata.</summary>
private protected FlatBufferBuilder Builder { get; }

/// <summary>Whether the schema message has already been written.</summary>
protected bool HasWrittenSchema { get; set; }

/// <summary>Whether the dictionary batches have already been written.</summary>
private bool HasWrittenDictionaryBatch { get; set; }

/// <summary>Set once <see cref="WriteStart"/> or <see cref="WriteStartAsync"/> has run.</summary>
private bool HasWrittenStart { get; set; }

/// <summary>Set once <see cref="WriteEnd"/> or <see cref="WriteEndAsync"/> has run.</summary>
private bool HasWrittenEnd { get; set; }

/// <summary>The schema supplied at construction time.</summary>
protected Schema Schema { get; }

// Constructor argument; presumably controls whether BaseStream is left open on dispose
// (the disposing code is not part of this section - confirm there).
private readonly bool _leaveOpen;

// IPC options supplied at construction time, or IpcOptions.Default when none were given.
private readonly IpcOptions _options;

/// <summary>Metadata version stamped on every message created by this writer.</summary>
private protected const Flatbuf.MetadataVersion CurrentMetadataVersion = Flatbuf.MetadataVersion.V4;

// 64 zero bytes; presumably the source for padding writes (the padding helpers are not in this section).
private static readonly byte[] s_padding = new byte[64];

// Builds the FlatBuffer type information for fields; shares Builder with this writer.
private readonly ArrowTypeFlatbufferBuilder _fieldTypeBuilder;

// Created lazily by the DictionaryMemo property; also passed by ref to DictionaryCollector.Collect.
private DictionaryMemo _dictionaryMemo;
private DictionaryMemo DictionaryMemo => _dictionaryMemo ??= new DictionaryMemo();
205 | ||
/// <summary>
/// Creates a writer over <paramref name="baseStream"/> with <c>leaveOpen</c> set to false
/// and default IPC options.
/// </summary>
public ArrowStreamWriter(Stream baseStream, Schema schema)
    : this(baseStream, schema, leaveOpen: false)
{
}

/// <summary>
/// Creates a writer over <paramref name="baseStream"/> with default IPC options.
/// </summary>
public ArrowStreamWriter(Stream baseStream, Schema schema, bool leaveOpen)
    : this(baseStream, schema, leaveOpen, options: null)
{
}

/// <summary>
/// Creates a writer over <paramref name="baseStream"/> for data described by <paramref name="schema"/>.
/// </summary>
/// <param name="baseStream">Destination stream; must not be null.</param>
/// <param name="schema">Schema of the data to be written; must not be null.</param>
/// <param name="leaveOpen">Stored on the writer for later use.</param>
/// <param name="options">IPC options; <see cref="IpcOptions.Default"/> is used when null.</param>
/// <exception cref="ArgumentNullException">
/// <paramref name="baseStream"/> or <paramref name="schema"/> is null.
/// </exception>
public ArrowStreamWriter(Stream baseStream, Schema schema, bool leaveOpen, IpcOptions options)
{
    if (baseStream is null)
    {
        throw new ArgumentNullException(nameof(baseStream));
    }

    if (schema is null)
    {
        throw new ArgumentNullException(nameof(schema));
    }

    BaseStream = baseStream;
    Schema = schema;
    _leaveOpen = leaveOpen;

    Buffers = ArrayPool<byte>.Create();
    Builder = new FlatBufferBuilder(1024);
    HasWrittenSchema = false;

    _fieldTypeBuilder = new ArrowTypeFlatbufferBuilder(Builder);
    _options = options ?? IpcOptions.Default;
}
229 | ||
230 | ||
/// <summary>
/// Emits a field node for <paramref name="data"/> into the builder, preceded (for nested types)
/// by the field nodes of its children, visited last-to-first.
/// </summary>
private void CreateSelfAndChildrenFieldNodes(ArrayData data)
{
    bool isNested = data.DataType is NestedType;
    if (isNested)
    {
        // flatbuffer struct vectors have to be created in reverse order
        int childIndex = data.Children.Length;
        while (childIndex > 0)
        {
            childIndex--;
            CreateSelfAndChildrenFieldNodes(data.Children[childIndex]);
        }
    }

    Flatbuf.FieldNode.CreateFieldNode(Builder, data.Length, data.NullCount);
}
243 | ||
/// <summary>
/// Counts the field nodes needed for <paramref name="fields"/>: one per field plus one for
/// every descendant field of nested types.
/// </summary>
private static int CountAllNodes(IReadOnlyDictionary<string, Field> fields)
{
    int total = 0;

    foreach (Field field in fields.Values)
    {
        CountSelfAndChildrenNodes(field.DataType, ref total);
    }

    return total;
}
253 | ||
/// <summary>
/// Adds to <paramref name="count"/> one node for <paramref name="type"/> and, if it is a
/// nested type, one node for each of its descendant fields.
/// </summary>
private static void CountSelfAndChildrenNodes(IArrowType type, ref int count)
{
    if (type is NestedType parent)
    {
        foreach (Field child in parent.Fields)
        {
            CountSelfAndChildrenNodes(child.DataType, ref count);
        }
    }

    // The node for the type itself.
    count += 1;
}
265 | ||
/// <summary>
/// Writes <paramref name="recordBatch"/> as a RecordBatch message followed by its buffer data.
/// The schema and the dictionary batches are written first if they have not been written yet.
/// </summary>
private protected void WriteRecordBatchInternal(RecordBatch recordBatch)
{
    // TODO: Truncate buffers with extraneous padding / unused capacity

    if (!HasWrittenSchema)
    {
        WriteSchema(Schema);
        HasWrittenSchema = true;
    }

    if (!HasWrittenDictionaryBatch)
    {
        DictionaryCollector.Collect(recordBatch, ref _dictionaryMemo);
        WriteDictionaries(recordBatch);
        HasWrittenDictionaryBatch = true;
    }

    Tuple<ArrowRecordBatchFlatBufferBuilder, VectorOffset> prepared =
        PreparingWritingRecordBatch(recordBatch);
    ArrowRecordBatchFlatBufferBuilder bodyLayout = prepared.Item1;
    VectorOffset nodesVector = prepared.Item2;

    // The preparation step starts the buffers vector but leaves it open; close it here.
    VectorOffset buffersVector = Builder.EndVector();

    // Serialize record batch

    StartingWritingRecordBatch();

    Offset<Flatbuf.RecordBatch> header = Flatbuf.RecordBatch.CreateRecordBatch(
        Builder, recordBatch.Length, nodesVector, buffersVector);

    long metadataLength = WriteMessage(
        Flatbuf.MessageHeader.RecordBatch, header, bodyLayout.TotalLength);

    long bufferLength = WriteBufferData(bodyLayout.Buffers);

    FinishedWritingRecordBatch(bufferLength, metadataLength);
}
303 | ||
/// <summary>
/// Asynchronously writes <paramref name="recordBatch"/> as a RecordBatch message followed by its
/// buffer data. The schema and the dictionary batches are written first if they have not been
/// written yet.
/// </summary>
private protected async Task WriteRecordBatchInternalAsync(RecordBatch recordBatch,
    CancellationToken cancellationToken = default)
{
    // TODO: Truncate buffers with extraneous padding / unused capacity

    if (!HasWrittenSchema)
    {
        await WriteSchemaAsync(Schema, cancellationToken).ConfigureAwait(false);
        HasWrittenSchema = true;
    }

    if (!HasWrittenDictionaryBatch)
    {
        DictionaryCollector.Collect(recordBatch, ref _dictionaryMemo);
        await WriteDictionariesAsync(recordBatch, cancellationToken).ConfigureAwait(false);
        HasWrittenDictionaryBatch = true;
    }

    Tuple<ArrowRecordBatchFlatBufferBuilder, VectorOffset> prepared =
        PreparingWritingRecordBatch(recordBatch);
    ArrowRecordBatchFlatBufferBuilder bodyLayout = prepared.Item1;
    VectorOffset nodesVector = prepared.Item2;

    // The preparation step starts the buffers vector but leaves it open; close it here.
    VectorOffset buffersVector = Builder.EndVector();

    // Serialize record batch

    StartingWritingRecordBatch();

    Offset<Flatbuf.RecordBatch> header = Flatbuf.RecordBatch.CreateRecordBatch(
        Builder, recordBatch.Length, nodesVector, buffersVector);

    long metadataLength = await WriteMessageAsync(
            Flatbuf.MessageHeader.RecordBatch, header, bodyLayout.TotalLength, cancellationToken)
        .ConfigureAwait(false);

    long bufferLength = await WriteBufferDataAsync(bodyLayout.Buffers, cancellationToken)
        .ConfigureAwait(false);

    FinishedWritingRecordBatch(bufferLength, metadataLength);
}
343 | ||
/// <summary>
/// Writes every non-empty buffer to the stream, each followed by padding up to the next
/// multiple of 8 bytes, then any padding computed for the body as a whole.
/// </summary>
/// <returns>The total number of body bytes accounted for, including padding.</returns>
private long WriteBufferData(IReadOnlyList<ArrowRecordBatchFlatBufferBuilder.Buffer> buffers)
{
    long written = 0;

    foreach (ArrowRecordBatchFlatBufferBuilder.Buffer entry in buffers)
    {
        ArrowBuffer data = entry.DataBuffer;
        if (!data.IsEmpty)
        {
            WriteBuffer(data);

            int paddedLength = checked((int)BitUtility.RoundUpToMultipleOf8(data.Length));
            int tail = paddedLength - data.Length;
            if (tail > 0)
            {
                WritePadding(tail);
            }

            written += paddedLength;
        }
    }

    // Write padding so the record batch message body length is a multiple of 8 bytes

    int bodyPaddingLength = CalculatePadding(written);
    WritePadding(bodyPaddingLength);

    return written + bodyPaddingLength;
}
374 | ||
/// <summary>
/// Asynchronously writes every non-empty buffer to the stream, each followed by padding up to
/// the next multiple of 8 bytes, then any padding computed for the body as a whole.
/// </summary>
/// <returns>The total number of body bytes accounted for, including padding.</returns>
private async ValueTask<long> WriteBufferDataAsync(IReadOnlyList<ArrowRecordBatchFlatBufferBuilder.Buffer> buffers, CancellationToken cancellationToken = default)
{
    long written = 0;

    foreach (ArrowRecordBatchFlatBufferBuilder.Buffer entry in buffers)
    {
        ArrowBuffer data = entry.DataBuffer;
        if (!data.IsEmpty)
        {
            await WriteBufferAsync(data, cancellationToken).ConfigureAwait(false);

            int paddedLength = checked((int)BitUtility.RoundUpToMultipleOf8(data.Length));
            int tail = paddedLength - data.Length;
            if (tail > 0)
            {
                await WritePaddingAsync(tail).ConfigureAwait(false);
            }

            written += paddedLength;
        }
    }

    // Write padding so the record batch message body length is a multiple of 8 bytes

    int bodyPaddingLength = CalculatePadding(written);
    await WritePaddingAsync(bodyPaddingLength).ConfigureAwait(false);

    return written + bodyPaddingLength;
}
405 | ||
/// <summary>
/// Prepares the field nodes and buffer layout for <paramref name="recordBatch"/> using the
/// batch's own schema fields and arrays.
/// </summary>
private Tuple<ArrowRecordBatchFlatBufferBuilder, VectorOffset> PreparingWritingRecordBatch(RecordBatch recordBatch)
    => PreparingWritingRecordBatch(recordBatch.Schema.Fields, recordBatch.ArrayList);
410 | ||
/// <summary>
/// Clears the builder, serializes the field-node vector for <paramref name="arrays"/>, collects the
/// buffer layout, and starts (but does not end) the buffers vector.
/// </summary>
/// <remarks>
/// The buffers vector is left open on return; the caller is expected to call
/// <c>Builder.EndVector()</c> to obtain its offset.
/// </remarks>
/// <returns>The collected buffer layout and the offset of the finished field-node vector.</returns>
private Tuple<ArrowRecordBatchFlatBufferBuilder, VectorOffset> PreparingWritingRecordBatch(IReadOnlyDictionary<string, Field> fields, IReadOnlyList<IArrowArray> arrays)
{
    Builder.Clear();

    // Serialize field nodes

    int columnCount = fields.Count;

    Flatbuf.RecordBatch.StartNodesVector(Builder, CountAllNodes(fields));

    // flatbuffer struct vectors have to be created in reverse order
    int column = columnCount;
    while (column > 0)
    {
        column--;
        CreateSelfAndChildrenFieldNodes(arrays[column].Data);
    }

    VectorOffset nodesVector = Builder.EndVector();

    // Serialize buffers

    var layout = new ArrowRecordBatchFlatBufferBuilder();
    for (column = 0; column < columnCount; column++)
    {
        arrays[column].Accept(layout);
    }

    IReadOnlyList<ArrowRecordBatchFlatBufferBuilder.Buffer> collected = layout.Buffers;

    Flatbuf.RecordBatch.StartBuffersVector(Builder, collected.Count);

    // flatbuffer struct vectors have to be created in reverse order
    int bufferIndex = collected.Count;
    while (bufferIndex > 0)
    {
        bufferIndex--;
        Flatbuf.Buffer.CreateBuffer(
            Builder, collected[bufferIndex].Offset, collected[bufferIndex].DataBuffer.Length);
    }

    return Tuple.Create(layout, nodesVector);
}
451 | ||
452 | ||
/// <summary>
/// Calls <see cref="WriteDictionary"/> for every top-level field of the schema of
/// <paramref name="recordBatch"/>.
/// </summary>
private protected void WriteDictionaries(RecordBatch recordBatch)
{
    IEnumerable<Field> topLevelFields = recordBatch.Schema.Fields.Values;

    foreach (Field topLevelField in topLevelFields)
    {
        WriteDictionary(topLevelField);
    }
}
460 | ||
/// <summary>
/// Writes the dictionary batch for <paramref name="field"/> when it is dictionary-typed;
/// otherwise recurses into the child fields of nested types and writes nothing for other types.
/// </summary>
private protected void WriteDictionary(Field field)
{
    if (field.DataType.TypeId == ArrowTypeId.Dictionary)
    {
        Tuple<ArrowRecordBatchFlatBufferBuilder, Offset<Flatbuf.DictionaryBatch>> batch =
            CreateDictionaryBatchOffset(field);
        ArrowRecordBatchFlatBufferBuilder bodyLayout = batch.Item1;

        WriteMessage(Flatbuf.MessageHeader.DictionaryBatch, batch.Item2, bodyLayout.TotalLength);

        WriteBufferData(bodyLayout.Buffers);
        return;
    }

    if (field.DataType is NestedType parent)
    {
        foreach (Field childField in parent.Fields)
        {
            WriteDictionary(childField);
        }
    }
}
483 | ||
/// <summary>
/// Asynchronously calls <see cref="WriteDictionaryAsync"/>, one field at a time, for every
/// top-level field of the schema of <paramref name="recordBatch"/>.
/// </summary>
private protected async Task WriteDictionariesAsync(RecordBatch recordBatch, CancellationToken cancellationToken)
{
    IEnumerable<Field> topLevelFields = recordBatch.Schema.Fields.Values;

    foreach (Field topLevelField in topLevelFields)
    {
        await WriteDictionaryAsync(topLevelField, cancellationToken).ConfigureAwait(false);
    }
}
491 | ||
/// <summary>
/// Asynchronously writes the dictionary batch for <paramref name="field"/> when it is
/// dictionary-typed; otherwise recurses into the child fields of nested types and writes
/// nothing for other types.
/// </summary>
private protected async Task WriteDictionaryAsync(Field field, CancellationToken cancellationToken)
{
    if (field.DataType.TypeId == ArrowTypeId.Dictionary)
    {
        Tuple<ArrowRecordBatchFlatBufferBuilder, Offset<Flatbuf.DictionaryBatch>> batch =
            CreateDictionaryBatchOffset(field);
        ArrowRecordBatchFlatBufferBuilder bodyLayout = batch.Item1;

        await WriteMessageAsync(
                Flatbuf.MessageHeader.DictionaryBatch, batch.Item2, bodyLayout.TotalLength, cancellationToken)
            .ConfigureAwait(false);

        await WriteBufferDataAsync(bodyLayout.Buffers, cancellationToken).ConfigureAwait(false);
        return;
    }

    if (field.DataType is NestedType parent)
    {
        foreach (Field childField in parent.Fields)
        {
            await WriteDictionaryAsync(childField, cancellationToken).ConfigureAwait(false);
        }
    }
}
514 | ||
/// <summary>
/// Builds the DictionaryBatch FlatBuffer table for the dictionary registered for
/// <paramref name="field"/> in the dictionary memo.
/// </summary>
/// <remarks>
/// The dictionary values are serialized as a single-column record batch whose only field is a
/// placeholder named "dummy" with the dictionary's value type.
/// </remarks>
/// <returns>The buffer layout of the dictionary values and the offset of the DictionaryBatch table.</returns>
private Tuple<ArrowRecordBatchFlatBufferBuilder, Offset<Flatbuf.DictionaryBatch>> CreateDictionaryBatchOffset(Field field)
{
    IArrowType valueType = ((DictionaryType)field.DataType).ValueType;
    Field placeholder = new Field("dummy", valueType, false);

    long dictionaryId = DictionaryMemo.GetId(field);
    IArrowArray values = DictionaryMemo.GetDictionary(dictionaryId);

    var singleField = new Dictionary<string, Field>();
    singleField.Add(placeholder.Name, placeholder);

    var singleArray = new List<IArrowArray>();
    singleArray.Add(values);

    Tuple<ArrowRecordBatchFlatBufferBuilder, VectorOffset> prepared =
        PreparingWritingRecordBatch(singleField, singleArray);
    VectorOffset nodesVector = prepared.Item2;

    // The preparation step starts the buffers vector but leaves it open; close it here.
    VectorOffset buffersVector = Builder.EndVector();

    // Serialize record batch
    Offset<Flatbuf.RecordBatch> recordBatchOffset = Flatbuf.RecordBatch.CreateRecordBatch(
        Builder, values.Length, nodesVector, buffersVector);

    // TODO: Support delta.
    Offset<Flatbuf.DictionaryBatch> dictionaryBatchOffset =
        Flatbuf.DictionaryBatch.CreateDictionaryBatch(Builder, dictionaryId, recordBatchOffset, false);

    return Tuple.Create(prepared.Item1, dictionaryBatchOffset);
}
540 | ||
/// <summary>
/// Writes the schema message unless it has already been written.
/// </summary>
private protected virtual void WriteStartInternal()
{
    if (HasWrittenSchema)
    {
        return;
    }

    WriteSchema(Schema);
    HasWrittenSchema = true;
}
549 | ||
/// <summary>
/// Asynchronously writes the schema message unless it has already been written.
/// </summary>
private protected virtual async ValueTask WriteStartInternalAsync(CancellationToken cancellationToken)
{
    if (HasWrittenSchema)
    {
        return;
    }

    await WriteSchemaAsync(Schema, cancellationToken).ConfigureAwait(false);
    HasWrittenSchema = true;
}
558 | ||
/// <summary>
/// Writes the end marker: an IPC message length prefix with a length of zero.
/// </summary>
private protected virtual void WriteEndInternal() => WriteIpcMessageLength(length: 0);
563 | ||
/// <summary>
/// Asynchronously writes the end marker: an IPC message length prefix with a length of zero.
/// </summary>
private protected virtual ValueTask WriteEndInternalAsync(CancellationToken cancellationToken)
    => WriteIpcMessageLengthAsync(length: 0, cancellationToken);
568 | ||
/// <summary>
/// Hook invoked just before a record batch message is serialized. The base implementation
/// does nothing.
/// </summary>
private protected virtual void StartingWritingRecordBatch()
{
    // Intentionally empty.
}
572 | ||
/// <summary>
/// Hook invoked after a record batch message and its body have been written. The base
/// implementation does nothing.
/// </summary>
/// <param name="bodyLength">Value returned by the buffer-data write for this batch.</param>
/// <param name="metadataLength">Value returned by the message write for this batch.</param>
private protected virtual void FinishedWritingRecordBatch(long bodyLength, long metadataLength)
{
    // Intentionally empty.
}
576 | ||
/// <summary>
/// Writes <paramref name="recordBatch"/> to the stream.
/// </summary>
public virtual void WriteRecordBatch(RecordBatch recordBatch)
    => WriteRecordBatchInternal(recordBatch);
581 | ||
/// <summary>
/// Asynchronously writes <paramref name="recordBatch"/> to the stream.
/// </summary>
public virtual Task WriteRecordBatchAsync(RecordBatch recordBatch, CancellationToken cancellationToken = default)
    => WriteRecordBatchInternalAsync(recordBatch, cancellationToken);
586 | ||
/// <summary>
/// Writes the start of the stream. Only the first call has any effect.
/// </summary>
public void WriteStart()
{
    if (HasWrittenStart)
    {
        return;
    }

    WriteStartInternal();
    HasWrittenStart = true;
}
595 | ||
/// <summary>
/// Asynchronously writes the start of the stream. Only the first call that completes
/// successfully has any effect.
/// </summary>
/// <param name="cancellationToken">Token forwarded to the underlying write.</param>
public async Task WriteStartAsync(CancellationToken cancellationToken = default)
{
    if (!HasWrittenStart)
    {
        // Library code: do not capture the caller's synchronization context,
        // consistent with every other await in this class.
        await WriteStartInternalAsync(cancellationToken).ConfigureAwait(false);
        HasWrittenStart = true;
    }
}
604 | ||
/// <summary>
/// Writes the end of the stream. Only the first call has any effect.
/// </summary>
public void WriteEnd()
{
    if (HasWrittenEnd)
    {
        return;
    }

    WriteEndInternal();
    HasWrittenEnd = true;
}
613 | ||
/// <summary>
/// Asynchronously writes the end of the stream. Only the first call that completes
/// successfully has any effect.
/// </summary>
/// <param name="cancellationToken">Token forwarded to the underlying write.</param>
public async Task WriteEndAsync(CancellationToken cancellationToken = default)
{
    if (!HasWrittenEnd)
    {
        // Library code: do not capture the caller's synchronization context,
        // consistent with every other await in this class.
        await WriteEndInternalAsync(cancellationToken).ConfigureAwait(false);
        HasWrittenEnd = true;
    }
}
622 | ||
/// <summary>
/// Writes the memory of <paramref name="arrowBuffer"/> to the stream.
/// </summary>
private void WriteBuffer(ArrowBuffer arrowBuffer) => BaseStream.Write(arrowBuffer.Memory);
627 | ||
/// <summary>
/// Asynchronously writes the memory of <paramref name="arrowBuffer"/> to the stream.
/// </summary>
private ValueTask WriteBufferAsync(ArrowBuffer arrowBuffer, CancellationToken cancellationToken = default)
    => BaseStream.WriteAsync(arrowBuffer.Memory, cancellationToken);
632 | ||
/// <summary>
/// Serializes <paramref name="schema"/> (custom metadata, fields, endianness) into the builder.
/// </summary>
/// <returns>The offset of the resulting Schema table.</returns>
private protected Offset<Flatbuf.Schema> SerializeSchema(Schema schema)
{
    // Build metadata
    VectorOffset customMetadata = default;
    if (schema.HasMetadata)
    {
        customMetadata = Flatbuf.Schema.CreateCustomMetadataVector(
            Builder, GetMetadataOffsets(schema.Metadata));
    }

    // Build fields
    int fieldCount = schema.Fields.Count;
    var serializedFields = new Offset<Flatbuf.Field>[fieldCount];
    for (int index = 0; index < serializedFields.Length; index++)
    {
        Field current = schema.GetFieldByIndex(index);

        // Everything a Field table references is created before the table itself.
        StringOffset name = Builder.CreateString(current.Name);
        ArrowTypeFlatbufferBuilder.FieldType typeInfo = _fieldTypeBuilder.BuildFieldType(current);
        VectorOffset children = GetChildrenFieldOffset(current);
        VectorOffset fieldMetadata = GetFieldMetadataOffset(current);
        Offset<Flatbuf.DictionaryEncoding> dictionary = GetDictionaryOffset(current);

        serializedFields[index] = Flatbuf.Field.CreateField(
            Builder,
            name, current.IsNullable, typeInfo.Type, typeInfo.Offset,
            dictionary, children, fieldMetadata);
    }

    VectorOffset fieldsVector = Flatbuf.Schema.CreateFieldsVector(Builder, serializedFields);

    // Build schema

    Flatbuf.Endianness endianness;
    if (BitConverter.IsLittleEndian)
    {
        endianness = Flatbuf.Endianness.Little;
    }
    else
    {
        endianness = Flatbuf.Endianness.Big;
    }

    return Flatbuf.Schema.CreateSchema(Builder, endianness, fieldsVector, customMetadata);
}
669 | ||
/// <summary>
/// Serializes the child fields of <paramref name="field"/> (recursively) into a vector of
/// Field tables. For a dictionary-typed field the children of its value type are used.
/// </summary>
/// <returns>
/// The offset of the children vector, or <c>default</c> when the relevant type is not nested.
/// </returns>
private VectorOffset GetChildrenFieldOffset(Field field)
{
    IArrowType effectiveType = field.DataType;
    if (effectiveType is DictionaryType dictionaryType)
    {
        effectiveType = dictionaryType.ValueType;
    }

    if (effectiveType is NestedType nested)
    {
        int count = nested.Fields.Count;
        var serialized = new Offset<Flatbuf.Field>[count];

        for (int index = 0; index < count; index++)
        {
            Field child = nested.Fields[index];

            // Everything a Field table references is created before the table itself.
            StringOffset name = Builder.CreateString(child.Name);
            ArrowTypeFlatbufferBuilder.FieldType typeInfo = _fieldTypeBuilder.BuildFieldType(child);
            VectorOffset grandChildren = GetChildrenFieldOffset(child);
            VectorOffset childMetadata = GetFieldMetadataOffset(child);
            Offset<Flatbuf.DictionaryEncoding> dictionary = GetDictionaryOffset(child);

            serialized[index] = Flatbuf.Field.CreateField(
                Builder,
                name, child.IsNullable, typeInfo.Type, typeInfo.Offset,
                dictionary, grandChildren, childMetadata);
        }

        return Builder.CreateVectorOfTables(serialized);
    }

    return default;
}
701 | ||
/// <summary>
/// Serializes the custom metadata of <paramref name="field"/>.
/// </summary>
/// <returns>
/// The offset of the metadata vector, or <c>default</c> when the field has no metadata.
/// </returns>
private VectorOffset GetFieldMetadataOffset(Field field)
{
    if (field.HasMetadata)
    {
        Offset<Flatbuf.KeyValue>[] entries = GetMetadataOffsets(field.Metadata);
        return Flatbuf.Field.CreateCustomMetadataVector(Builder, entries);
    }

    return default;
}
712 | ||
/// <summary>
/// Serializes the DictionaryEncoding table for <paramref name="field"/>, obtaining (or assigning)
/// the field's dictionary id from the dictionary memo.
/// </summary>
/// <returns>
/// The offset of the DictionaryEncoding table, or <c>default</c> when the field is not
/// dictionary-typed.
/// </returns>
/// <exception cref="InvalidCastException">
/// The field reports <see cref="ArrowTypeId.Dictionary"/> but its data type is not a
/// <see cref="DictionaryType"/>.
/// </exception>
/// <exception cref="NotSupportedException">
/// The dictionary's index type is not a <see cref="NumberType"/>.
/// </exception>
private Offset<Flatbuf.DictionaryEncoding> GetDictionaryOffset(Field field)
{
    if (field.DataType.TypeId != ArrowTypeId.Dictionary)
    {
        return default;
    }

    long id = DictionaryMemo.GetOrAssignId(field);

    // Direct cast / pattern match instead of unchecked 'as': a type mismatch must surface as a
    // descriptive exception rather than a NullReferenceException on the next dereference.
    var dictionaryType = (DictionaryType)field.DataType;
    if (!(dictionaryType.IndexType is NumberType indexType))
    {
        throw new NotSupportedException(
            $"Dictionary index type '{dictionaryType.IndexType?.GetType().Name ?? "null"}' of field '{field.Name}' is not a numeric type.");
    }

    Offset<Flatbuf.Int> indexOffset = Flatbuf.Int.CreateInt(Builder, indexType.BitWidth, indexType.IsSigned);
    return Flatbuf.DictionaryEncoding.CreateDictionaryEncoding(Builder, id, indexOffset, dictionaryType.Ordered);
}
727 | ||
/// <summary>
/// Serializes every key/value pair of <paramref name="metadata"/> into a KeyValue table.
/// </summary>
/// <param name="metadata">Metadata to serialize; asserted (debug builds) to be non-null and non-empty.</param>
/// <returns>The offsets of the created KeyValue tables, in enumeration order.</returns>
private Offset<Flatbuf.KeyValue>[] GetMetadataOffsets(IReadOnlyDictionary<string, string> metadata)
{
    Debug.Assert(metadata != null);
    Debug.Assert(metadata.Count > 0);

    var result = new Offset<Flatbuf.KeyValue>[metadata.Count];
    int next = 0;

    foreach (KeyValuePair<string, string> pair in metadata)
    {
        // Strings must be created before the table that references them.
        StringOffset key = Builder.CreateString(pair.Key);
        StringOffset value = Builder.CreateString(pair.Value);

        result[next] = Flatbuf.KeyValue.CreateKeyValue(Builder, key, value);
        next++;
    }

    return result;
}
745 | ||
/// <summary>
/// Clears the builder, serializes <paramref name="schema"/> and writes it to the stream as a
/// Schema message with a body length of zero.
/// </summary>
/// <returns>The offset of the serialized Schema table.</returns>
private Offset<Flatbuf.Schema> WriteSchema(Schema schema)
{
    Builder.Clear();

    // Build schema
    Offset<Flatbuf.Schema> serialized = SerializeSchema(schema);

    // Build message
    WriteMessage(Flatbuf.MessageHeader.Schema, serialized, 0);

    return serialized;
}
760 | ||
/// <summary>
/// Clears the builder, serializes <paramref name="schema"/> and asynchronously writes it to the
/// stream as a Schema message with a body length of zero.
/// </summary>
/// <returns>The offset of the serialized Schema table.</returns>
private async ValueTask<Offset<Flatbuf.Schema>> WriteSchemaAsync(Schema schema, CancellationToken cancellationToken)
{
    Builder.Clear();

    // Build schema
    Offset<Flatbuf.Schema> serialized = SerializeSchema(schema);

    // Build message
    await WriteMessageAsync(Flatbuf.MessageHeader.Schema, serialized, 0, cancellationToken)
        .ConfigureAwait(false);

    return serialized;
}
776 | ||
/// <summary>
/// Wraps the given header in a Message table, finishes the builder, and writes the length
/// prefix, the message bytes and the alignment padding to the <see cref="BaseStream"/>.
/// </summary>
/// <param name="headerType">The kind of header contained in the message.</param>
/// <param name="headerOffset">Offset of the already-serialized header table.</param>
/// <param name="bodyLength">Body length recorded in the message.</param>
/// <returns>
/// The number of bytes written to the stream.
/// </returns>
private protected long WriteMessage<T>(
    Flatbuf.MessageHeader headerType, Offset<T> headerOffset, int bodyLength)
    where T : struct
{
    Offset<Flatbuf.Message> message = Flatbuf.Message.CreateMessage(
        Builder, CurrentMetadataVersion, headerType, headerOffset.Value,
        bodyLength);

    Builder.Finish(message.Value);

    ReadOnlyMemory<byte> payload = Builder.DataBuffer.ToReadOnlyMemory(Builder.DataBuffer.Position, Builder.Offset);
    int prefixSize = _options.SizeOfIpcLength;
    int paddingLength = CalculatePadding(prefixSize + payload.Length);

    // The length prefix covers the message bytes plus the padding that follows them.
    WriteIpcMessageLength(payload.Length + paddingLength);

    BaseStream.Write(payload);
    WritePadding(paddingLength);

    return checked(prefixSize + payload.Length + paddingLength);
}
806 | ||
/// <summary>
/// Wraps the given header in a Message table, finishes the builder, and asynchronously writes
/// the length prefix, the message bytes and the alignment padding to the <see cref="BaseStream"/>.
/// </summary>
/// <param name="headerType">The kind of header contained in the message.</param>
/// <param name="headerOffset">Offset of the already-serialized header table.</param>
/// <param name="bodyLength">Body length recorded in the message.</param>
/// <param name="cancellationToken">Token forwarded to the stream writes.</param>
/// <returns>
/// The number of bytes written to the stream.
/// </returns>
private protected virtual async ValueTask<long> WriteMessageAsync<T>(
    Flatbuf.MessageHeader headerType, Offset<T> headerOffset, int bodyLength,
    CancellationToken cancellationToken)
    where T : struct
{
    Offset<Flatbuf.Message> message = Flatbuf.Message.CreateMessage(
        Builder, CurrentMetadataVersion, headerType, headerOffset.Value,
        bodyLength);

    Builder.Finish(message.Value);

    ReadOnlyMemory<byte> payload = Builder.DataBuffer.ToReadOnlyMemory(Builder.DataBuffer.Position, Builder.Offset);
    int prefixSize = _options.SizeOfIpcLength;
    int paddingLength = CalculatePadding(prefixSize + payload.Length);

    // The length prefix covers the message bytes plus the padding that follows them.
    await WriteIpcMessageLengthAsync(payload.Length + paddingLength, cancellationToken)
        .ConfigureAwait(false);

    await BaseStream.WriteAsync(payload, cancellationToken).ConfigureAwait(false);
    await WritePaddingAsync(paddingLength).ConfigureAwait(false);

    return checked(prefixSize + payload.Length + paddingLength);
}
838 | ||
/// <summary>
/// Writes the region of <c>Builder.DataBuffer</c> selected by the buffer's
/// <c>Position</c> and <c>Builder.Offset</c> to the <see cref="BaseStream"/>.
/// </summary>
private protected void WriteFlatBuffer()
{
    BaseStream.Write(
        Builder.DataBuffer.ToReadOnlyMemory(Builder.DataBuffer.Position, Builder.Offset));
}
845 | ||
/// <summary>
/// Asynchronously writes the region of <c>Builder.DataBuffer</c> selected by the
/// buffer's <c>Position</c> and <c>Builder.Offset</c> to the <see cref="BaseStream"/>.
/// </summary>
/// <param name="cancellationToken">Token forwarded to the stream write.</param>
private protected async ValueTask WriteFlatBufferAsync(CancellationToken cancellationToken = default)
{
    ReadOnlyMemory<byte> flatBufferBytes = Builder.DataBuffer.ToReadOnlyMemory(
        Builder.DataBuffer.Position,
        Builder.Offset);

    await BaseStream
        .WriteAsync(flatBufferBytes, cancellationToken)
        .ConfigureAwait(false);
}
852 | ||
/// <summary>
/// Writes the IPC length prefix to the <see cref="BaseStream"/>: unless the legacy
/// IPC format is selected, a little-endian continuation token comes first, followed
/// by <paramref name="length"/> as a little-endian 32-bit integer.
/// </summary>
/// <param name="length">The length value to encode in the prefix.</param>
private void WriteIpcMessageLength(int length)
{
    Buffers.RentReturn(_options.SizeOfIpcLength, (buffer) =>
    {
        Memory<byte> prefix = buffer;

        // Byte position of the length field inside the prefix; it moves past the
        // continuation token when one is written.
        int lengthOffset = 0;

        bool legacyFormat = _options.WriteLegacyIpcFormat;
        if (!legacyFormat)
        {
            BinaryPrimitives.WriteInt32LittleEndian(
                prefix.Span, MessageSerializer.IpcContinuationToken);
            lengthOffset = sizeof(int);
        }

        BinaryPrimitives.WriteInt32LittleEndian(prefix.Span.Slice(lengthOffset), length);
        BaseStream.Write(buffer);
    });
}
869 | ||
/// <summary>
/// Asynchronously writes the IPC length prefix to the <see cref="BaseStream"/>: unless
/// the legacy IPC format is selected, a little-endian continuation token comes first,
/// followed by <paramref name="length"/> as a little-endian 32-bit integer.
/// </summary>
/// <param name="length">The length value to encode in the prefix.</param>
/// <param name="cancellationToken">Token forwarded to the stream write.</param>
private async ValueTask WriteIpcMessageLengthAsync(int length, CancellationToken cancellationToken)
{
    await Buffers.RentReturnAsync(_options.SizeOfIpcLength, async (buffer) =>
    {
        Memory<byte> prefix = buffer;

        // Byte position of the length field inside the prefix; it moves past the
        // continuation token when one is written.
        int lengthOffset = 0;

        bool legacyFormat = _options.WriteLegacyIpcFormat;
        if (!legacyFormat)
        {
            BinaryPrimitives.WriteInt32LittleEndian(
                prefix.Span, MessageSerializer.IpcContinuationToken);
            lengthOffset = sizeof(int);
        }

        BinaryPrimitives.WriteInt32LittleEndian(prefix.Span.Slice(lengthOffset), length);
        await BaseStream.WriteAsync(buffer, cancellationToken).ConfigureAwait(false);
    }).ConfigureAwait(false);
}
886 | ||
/// <summary>
/// Computes how many bytes must follow <paramref name="offset"/> to reach the value
/// returned by <c>BitUtility.RoundUpToMultiplePowerOfTwo</c> for the given alignment.
/// </summary>
/// <param name="offset">The current position or length to align.</param>
/// <param name="alignment">The alignment passed to the rounding helper; defaults to 8.</param>
/// <returns>The difference between the rounded value and <paramref name="offset"/>.</returns>
/// <exception cref="OverflowException">The difference does not fit in an <see cref="int"/>.</exception>
protected int CalculatePadding(long offset, int alignment = 8)
{
    long alignedOffset = BitUtility.RoundUpToMultiplePowerOfTwo(offset, alignment);
    long paddingBytes = alignedOffset - offset;

    // Only the narrowing conversion is overflow-checked.
    return checked((int)paddingBytes);
}
895 | ||
/// <summary>
/// Writes <paramref name="length"/> padding bytes, taken from <c>s_padding</c>, to the
/// <see cref="BaseStream"/>.
/// </summary>
/// <remarks>
/// The padding is emitted in chunks of at most <c>s_padding.Length</c> bytes, so a
/// request larger than the shared padding array is written in full rather than being
/// silently truncated to the array size.
/// </remarks>
/// <param name="length">Number of padding bytes to write; values of zero or less write nothing.</param>
private protected void WritePadding(int length)
{
    int remaining = length;

    // The second condition guards against spinning forever should the shared
    // padding array ever be empty.
    while (remaining > 0 && s_padding.Length > 0)
    {
        int chunk = Math.Min(s_padding.Length, remaining);
        BaseStream.Write(s_padding.AsMemory(0, chunk));
        remaining -= chunk;
    }
}
903 | ||
/// <summary>
/// Asynchronously writes <paramref name="length"/> padding bytes, taken from
/// <c>s_padding</c>, to the <see cref="BaseStream"/>.
/// </summary>
/// <remarks>
/// A request that fits in <c>s_padding</c> is issued as a single write whose
/// <see cref="ValueTask"/> is returned directly. A larger request is written in chunks
/// of at most <c>s_padding.Length</c> bytes, so it is emitted in full rather than being
/// silently truncated to the array size.
/// </remarks>
/// <param name="length">Number of padding bytes to write; values of zero or less write nothing.</param>
/// <returns>A task that completes when all padding has been written.</returns>
private protected ValueTask WritePaddingAsync(int length)
{
    if (length <= 0)
    {
        return default;
    }

    if (length <= s_padding.Length)
    {
        // Common case: one write, no async state machine needed.
        return BaseStream.WriteAsync(s_padding.AsMemory(0, length));
    }

    return WriteChunkedAsync(length);

    async ValueTask WriteChunkedAsync(int remaining)
    {
        // The second condition guards against spinning forever should the shared
        // padding array ever be empty.
        while (remaining > 0 && s_padding.Length > 0)
        {
            int chunk = Math.Min(s_padding.Length, remaining);
            await BaseStream.WriteAsync(s_padding.AsMemory(0, chunk)).ConfigureAwait(false);
            remaining -= chunk;
        }
    }
}
913 | ||
/// <summary>
/// Disposes the <see cref="BaseStream"/>, unless <c>_leaveOpen</c> is set.
/// </summary>
public virtual void Dispose()
{
    if (_leaveOpen)
    {
        return;
    }

    BaseStream.Dispose();
}
921 | } | |
922 | ||
/// <summary>
/// Walks the columns of a record batch, including nested children, and registers
/// every dictionary it finds in a <see cref="DictionaryMemo"/>.
/// </summary>
internal static class DictionaryCollector
{
    /// <summary>
    /// Collects the dictionaries of every column in <paramref name="recordBatch"/>.
    /// </summary>
    /// <param name="recordBatch">The batch whose schema fields and columns are visited.</param>
    /// <param name="dictionaryMemo">
    /// The memo that receives the dictionaries; it is created on demand when it is
    /// <c>null</c> and a dictionary-typed field is encountered.
    /// </param>
    internal static void Collect(RecordBatch recordBatch, ref DictionaryMemo dictionaryMemo)
    {
        Schema schema = recordBatch.Schema;

        for (int columnIndex = 0; columnIndex < schema.Fields.Count; columnIndex++)
        {
            Field columnField = schema.GetFieldByIndex(columnIndex);
            ArrayData columnData = recordBatch.Column(columnIndex).Data;

            CollectDictionary(columnField, columnData, ref dictionaryMemo);
        }
    }

    /// <summary>
    /// Registers the dictionary of a dictionary-typed field, then continues into the
    /// dictionary's children; for any other field, continues into the children of
    /// <paramref name="arrayData"/>.
    /// </summary>
    /// <exception cref="ArgumentException">
    /// The field is dictionary-typed but <paramref name="arrayData"/> carries no dictionary.
    /// </exception>
    private static void CollectDictionary(Field field, ArrayData arrayData, ref DictionaryMemo dictionaryMemo)
    {
        if (!(field.DataType is DictionaryType dictionaryType))
        {
            // Not a dictionary itself, but its children may be.
            WalkChildren(arrayData, ref dictionaryMemo);
            return;
        }

        if (arrayData.Dictionary == null)
        {
            throw new ArgumentException($"{nameof(arrayData.Dictionary)} must not be null");
        }

        arrayData.Dictionary.EnsureDataType(dictionaryType.ValueType.TypeId);
        IArrowArray dictionaryArray = ArrowArrayFactory.BuildArray(arrayData.Dictionary);

        dictionaryMemo ??= new DictionaryMemo();
        long dictionaryId = dictionaryMemo.GetOrAssignId(field);
        dictionaryMemo.AddOrReplaceDictionary(dictionaryId, dictionaryArray);

        // The dictionary values may themselves contain dictionary-typed children.
        WalkChildren(dictionaryArray.Data, ref dictionaryMemo);
    }

    /// <summary>
    /// Visits each child of <paramref name="arrayData"/> together with its field when
    /// the data has children and its type is a <see cref="NestedType"/>; otherwise
    /// does nothing.
    /// </summary>
    private static void WalkChildren(ArrayData arrayData, ref DictionaryMemo dictionaryMemo)
    {
        ArrayData[] childData = arrayData.Children;
        if (childData == null)
        {
            return;
        }

        if (!(arrayData.DataType is NestedType nestedType))
        {
            return;
        }

        // Children are paired with the nested type's fields by position.
        for (int childIndex = 0; childIndex < nestedType.Fields.Count; childIndex++)
        {
            CollectDictionary(nestedType.Fields[childIndex], childData[childIndex], ref dictionaryMemo);
        }
    }
}
982 | } |