]> git.proxmox.com Git - ceph.git/blobdiff - ceph/src/arrow/cpp/src/parquet/stream_writer.cc
import quincy 17.2.0
[ceph.git] / ceph / src / arrow / cpp / src / parquet / stream_writer.cc
diff --git a/ceph/src/arrow/cpp/src/parquet/stream_writer.cc b/ceph/src/arrow/cpp/src/parquet/stream_writer.cc
new file mode 100644 (file)
index 0000000..253ebf1
--- /dev/null
@@ -0,0 +1,324 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "parquet/stream_writer.h"
+
+#include <utility>
+
+namespace parquet {
+
+// Process-wide default cap on buffered row-group size, copied by new writers.
+int64_t StreamWriter::default_row_group_size_{512 * 1024 * 1024};  // 512MB
+
+// Out-of-class definitions for the in-class constexpr constants: they are
+// ODR-used (passed by address to WriteBatch), so pre-C++17 requires a
+// namespace-scope definition for each.
+constexpr int16_t StreamWriter::kDefLevelZero;
+constexpr int16_t StreamWriter::kDefLevelOne;
+constexpr int16_t StreamWriter::kRepLevelZero;
+constexpr int64_t StreamWriter::kBatchSizeOne;
+
+// View over a NUL-terminated C string; length is computed with strlen,
+// so data_ptr must not be null here.
+StreamWriter::FixedStringView::FixedStringView(const char* data_ptr)
+    : data{data_ptr}, size{std::strlen(data_ptr)} {}
+
+// View over an explicit-length buffer (need not be NUL-terminated).
+StreamWriter::FixedStringView::FixedStringView(const char* data_ptr, std::size_t data_len)
+    : data{data_ptr}, size{data_len} {}
+
+// Takes ownership of the file writer and immediately opens the first
+// buffered row group.  Each column's schema node is cached up front so
+// per-value type checks (CheckColumn) avoid re-walking the schema.
+StreamWriter::StreamWriter(std::unique_ptr<ParquetFileWriter> writer)
+    : file_writer_{std::move(writer)},
+      row_group_writer_{file_writer_->AppendBufferedRowGroup()} {
+  auto schema = file_writer_->schema();
+  auto group_node = schema->group_node();
+
+  nodes_.resize(schema->num_columns());
+
+  // Columns are assumed to be primitive leaves of the root group node.
+  for (auto i = 0; i < schema->num_columns(); ++i) {
+    nodes_[i] = std::static_pointer_cast<schema::PrimitiveNode>(group_node->field(i));
+  }
+}
+
+// Sets the process-wide default cap (bytes) picked up by writers created
+// afterwards; does not affect existing instances.
+void StreamWriter::SetDefaultMaxRowGroupSize(int64_t max_size) {
+  default_row_group_size_ = max_size;
+}
+
+// Sets this writer's cap (bytes).  A value <= 0 disables automatic
+// row-group splitting (see the max_row_group_size_ > 0 checks below).
+void StreamWriter::SetMaxRowGroupSize(int64_t max_size) {
+  max_row_group_size_ = max_size;
+}
+
+int StreamWriter::num_columns() const { return static_cast<int>(nodes_.size()); }
+
+// Each operator<< below first verifies that the current column's physical
+// and converted types match the C++ value being written, then writes one
+// value and advances to the next column.  Narrow and unsigned integers are
+// stored in the same-width signed physical type (INT32/INT64); the UINT_*/
+// INT_* converted type records the intended interpretation.
+
+StreamWriter& StreamWriter::operator<<(bool v) {
+  CheckColumn(Type::BOOLEAN, ConvertedType::NONE);
+  return Write<BoolWriter>(v);
+}
+
+StreamWriter& StreamWriter::operator<<(int8_t v) {
+  CheckColumn(Type::INT32, ConvertedType::INT_8);
+  return Write<Int32Writer>(static_cast<int32_t>(v));
+}
+
+StreamWriter& StreamWriter::operator<<(uint8_t v) {
+  CheckColumn(Type::INT32, ConvertedType::UINT_8);
+  return Write<Int32Writer>(static_cast<int32_t>(v));
+}
+
+StreamWriter& StreamWriter::operator<<(int16_t v) {
+  CheckColumn(Type::INT32, ConvertedType::INT_16);
+  return Write<Int32Writer>(static_cast<int32_t>(v));
+}
+
+StreamWriter& StreamWriter::operator<<(uint16_t v) {
+  CheckColumn(Type::INT32, ConvertedType::UINT_16);
+  return Write<Int32Writer>(static_cast<int32_t>(v));
+}
+
+StreamWriter& StreamWriter::operator<<(int32_t v) {
+  CheckColumn(Type::INT32, ConvertedType::INT_32);
+  return Write<Int32Writer>(v);
+}
+
+StreamWriter& StreamWriter::operator<<(uint32_t v) {
+  CheckColumn(Type::INT32, ConvertedType::UINT_32);
+  return Write<Int32Writer>(static_cast<int32_t>(v));
+}
+
+StreamWriter& StreamWriter::operator<<(int64_t v) {
+  CheckColumn(Type::INT64, ConvertedType::INT_64);
+  return Write<Int64Writer>(v);
+}
+
+StreamWriter& StreamWriter::operator<<(uint64_t v) {
+  CheckColumn(Type::INT64, ConvertedType::UINT_64);
+  return Write<Int64Writer>(static_cast<int64_t>(v));
+}
+
+// Durations map onto INT64 timestamp columns at the matching resolution.
+StreamWriter& StreamWriter::operator<<(const std::chrono::milliseconds& v) {
+  CheckColumn(Type::INT64, ConvertedType::TIMESTAMP_MILLIS);
+  return Write<Int64Writer>(static_cast<int64_t>(v.count()));
+}
+
+StreamWriter& StreamWriter::operator<<(const std::chrono::microseconds& v) {
+  CheckColumn(Type::INT64, ConvertedType::TIMESTAMP_MICROS);
+  return Write<Int64Writer>(static_cast<int64_t>(v.count()));
+}
+
+StreamWriter& StreamWriter::operator<<(float v) {
+  CheckColumn(Type::FLOAT, ConvertedType::NONE);
+  return Write<FloatWriter>(v);
+}
+
+StreamWriter& StreamWriter::operator<<(double v) {
+  CheckColumn(Type::DOUBLE, ConvertedType::NONE);
+  return Write<DoubleWriter>(v);
+}
+
+// A single char is written as a FIXED_LEN_BYTE_ARRAY of length 1.
+StreamWriter& StreamWriter::operator<<(char v) { return WriteFixedLength(&v, 1); }
+
+// FixedStringView targets FIXED_LEN_BYTE_ARRAY columns; the view's size
+// must equal the column's declared type length (checked downstream).
+StreamWriter& StreamWriter::operator<<(FixedStringView v) {
+  return WriteFixedLength(v.data, v.size);
+}
+
+// The three overloads below all target variable-length BYTE_ARRAY/UTF8
+// columns; only the way the pointer/length pair is obtained differs.
+StreamWriter& StreamWriter::operator<<(const char* v) {
+  return WriteVariableLength(v, std::strlen(v));
+}
+
+StreamWriter& StreamWriter::operator<<(const std::string& v) {
+  return WriteVariableLength(v.data(), v.size());
+}
+
+StreamWriter& StreamWriter::operator<<(::arrow::util::string_view v) {
+  return WriteVariableLength(v.data(), v.size());
+}
+
+// Writes one BYTE_ARRAY (UTF8) value to the current column and advances
+// the column index.  A null data_ptr writes a null entry (definition
+// level 0) instead of a value.  When automatic row-group splitting is
+// enabled, the column writer's buffered-byte estimate is accumulated so
+// EndRow() can decide whether to close the current row group.
+StreamWriter& StreamWriter::WriteVariableLength(const char* data_ptr,
+                                                std::size_t data_len) {
+  CheckColumn(Type::BYTE_ARRAY, ConvertedType::UTF8);
+
+  auto writer = static_cast<ByteArrayWriter*>(row_group_writer_->column(column_index_++));
+
+  if (data_ptr != nullptr) {
+    ByteArray ba_value;
+
+    ba_value.ptr = reinterpret_cast<const uint8_t*>(data_ptr);
+    ba_value.len = static_cast<uint32_t>(data_len);
+
+    writer->WriteBatch(kBatchSizeOne, &kDefLevelOne, &kRepLevelZero, &ba_value);
+  } else {
+    // Null value: definition level zero, no data.
+    writer->WriteBatch(kBatchSizeOne, &kDefLevelZero, &kRepLevelZero, nullptr);
+  }
+  if (max_row_group_size_ > 0) {
+    row_group_size_ += writer->EstimatedBufferedValueBytes();
+  }
+  return *this;
+}
+
+// Writes one FIXED_LEN_BYTE_ARRAY value to the current column and
+// advances the column index.  CheckColumn() enforces that data_len
+// exactly equals the column's declared type length; a null data_ptr
+// writes a null entry (definition level 0).
+StreamWriter& StreamWriter::WriteFixedLength(const char* data_ptr, std::size_t data_len) {
+  CheckColumn(Type::FIXED_LEN_BYTE_ARRAY, ConvertedType::NONE,
+              static_cast<int>(data_len));
+
+  auto writer =
+      static_cast<FixedLenByteArrayWriter*>(row_group_writer_->column(column_index_++));
+
+  if (data_ptr != nullptr) {
+    FixedLenByteArray flba_value;
+
+    // No length field: the column's fixed type length determines the size.
+    flba_value.ptr = reinterpret_cast<const uint8_t*>(data_ptr);
+    writer->WriteBatch(kBatchSizeOne, &kDefLevelOne, &kRepLevelZero, &flba_value);
+  } else {
+    writer->WriteBatch(kBatchSizeOne, &kDefLevelZero, &kRepLevelZero, nullptr);
+  }
+  if (max_row_group_size_ > 0) {
+    row_group_size_ += writer->EstimatedBufferedValueBytes();
+  }
+  return *this;
+}
+
+// Validates that the current column accepts the value about to be
+// written: the column index must be in range and the column's physical
+// type, converted type and (for fixed-length columns) type length must
+// all match exactly.  Throws ParquetException on any mismatch.
+void StreamWriter::CheckColumn(Type::type physical_type,
+                               ConvertedType::type converted_type, int length) {
+  if (static_cast<std::size_t>(column_index_) >= nodes_.size()) {
+    throw ParquetException("Column index out-of-bounds.  Index " +
+                           std::to_string(column_index_) + " is invalid for " +
+                           std::to_string(nodes_.size()) + " columns");
+  }
+  const auto& node = nodes_[column_index_];
+
+  if (physical_type != node->physical_type()) {
+    throw ParquetException("Column physical type mismatch.  Column '" + node->name() +
+                           "' has physical type '" + TypeToString(node->physical_type()) +
+                           "' not '" + TypeToString(physical_type) + "'");
+  }
+  if (converted_type != node->converted_type()) {
+    throw ParquetException("Column converted type mismatch.  Column '" + node->name() +
+                           "' has converted type[" +
+                           ConvertedTypeToString(node->converted_type()) + "] not '" +
+                           ConvertedTypeToString(converted_type) + "'");
+  }
+  // Length must be exact.
+  // A shorter length fixed array is not acceptable as it would
+  // result in array bound read errors.
+  //
+  if (length != node->type_length()) {
+    throw ParquetException("Column length mismatch.  Column '" + node->name() +
+                           "' has length " + std::to_string(node->type_length()) +
+                           " not " + std::to_string(length));
+  }
+}
+
+// Writes a null value into up to num_columns_to_skip columns starting at
+// the current column.  Stops early at the end of the row, so the return
+// value may be less than requested.  Throws if a column to be skipped is
+// required (required columns cannot hold nulls).
+int64_t StreamWriter::SkipColumns(int num_columns_to_skip) {
+  int num_columns_skipped = 0;
+
+  for (; (num_columns_to_skip > num_columns_skipped) &&
+         static_cast<std::size_t>(column_index_) < nodes_.size();
+       ++num_columns_skipped) {
+    const auto& node = nodes_[column_index_];
+
+    if (node->is_required()) {
+      throw ParquetException("Cannot skip column '" + node->name() +
+                             "' as it is required.");
+    }
+    auto writer = row_group_writer_->column(column_index_++);
+
+    WriteNullValue(writer);
+  }
+  return num_columns_skipped;
+}
+
+// Writes a single null entry (definition level 0, no data) through the
+// given column writer, dispatching on the column's physical type to pick
+// the correctly-typed WriteBatch overload.  INT96 is not supported by the
+// stream API and UNDEFINED should never reach here.
+void StreamWriter::WriteNullValue(ColumnWriter* writer) {
+  switch (writer->type()) {
+    case Type::BOOLEAN:
+      static_cast<BoolWriter*>(writer)->WriteBatch(kBatchSizeOne, &kDefLevelZero,
+                                                   &kRepLevelZero, nullptr);
+      break;
+    case Type::INT32:
+      static_cast<Int32Writer*>(writer)->WriteBatch(kBatchSizeOne, &kDefLevelZero,
+                                                    &kRepLevelZero, nullptr);
+      break;
+    case Type::INT64:
+      static_cast<Int64Writer*>(writer)->WriteBatch(kBatchSizeOne, &kDefLevelZero,
+                                                    &kRepLevelZero, nullptr);
+      break;
+    case Type::BYTE_ARRAY:
+      static_cast<ByteArrayWriter*>(writer)->WriteBatch(kBatchSizeOne, &kDefLevelZero,
+                                                        &kRepLevelZero, nullptr);
+      break;
+    case Type::FIXED_LEN_BYTE_ARRAY:
+      static_cast<FixedLenByteArrayWriter*>(writer)->WriteBatch(
+          kBatchSizeOne, &kDefLevelZero, &kRepLevelZero, nullptr);
+      break;
+    case Type::FLOAT:
+      static_cast<FloatWriter*>(writer)->WriteBatch(kBatchSizeOne, &kDefLevelZero,
+                                                    &kRepLevelZero, nullptr);
+      break;
+    case Type::DOUBLE:
+      static_cast<DoubleWriter*>(writer)->WriteBatch(kBatchSizeOne, &kDefLevelZero,
+                                                     &kRepLevelZero, nullptr);
+      break;
+    case Type::INT96:
+    case Type::UNDEFINED:
+      throw ParquetException("Unexpected type: " + TypeToString(writer->type()));
+      break;
+  }
+}
+
+// Skips exactly one optional column (writing a null).  Throws if no
+// column could be skipped, e.g. when the row is already complete.
+void StreamWriter::SkipOptionalColumn() {
+  if (SkipColumns(1) != 1) {
+    throw ParquetException("Failed to skip optional column at column index " +
+                           std::to_string(column_index_));
+  }
+}
+
+// Finishes the current row: verifies every column was written, resets
+// the column index and bumps the row counter.  When automatic splitting
+// is enabled, closes the row group once the accumulated size estimate
+// exceeds the configured maximum.
+void StreamWriter::EndRow() {
+  if (!file_writer_) {
+    throw ParquetException("StreamWriter not initialized");
+  }
+  if (static_cast<std::size_t>(column_index_) < nodes_.size()) {
+    throw ParquetException("Cannot end row with " + std::to_string(column_index_) +
+                           " of " + std::to_string(nodes_.size()) + " columns written");
+  }
+  column_index_ = 0;
+  ++current_row_;
+
+  if (max_row_group_size_ > 0) {
+    if (row_group_size_ > max_row_group_size_) {
+      EndRowGroup();
+    }
+    // Initialize for each row with size already written
+    // (compressed + uncompressed).
+    //
+    row_group_size_ = row_group_writer_->total_bytes_written() +
+                      row_group_writer_->total_compressed_bytes();
+  }
+}
+
+// Closes the current buffered row group and starts a new one.  A no-op
+// when the current group holds no rows, so empty row groups are never
+// emitted into the file.
+void StreamWriter::EndRowGroup() {
+  if (!file_writer_) {
+    throw ParquetException("StreamWriter not initialized");
+  }
+  // Avoid creating empty row groups.
+  if (row_group_writer_->num_rows() > 0) {
+    row_group_writer_->Close();
+    row_group_writer_.reset(file_writer_->AppendBufferedRowGroup());
+  }
+}
+
+// Stream manipulators: `os << EndRow` and `os << EndRowGroup` forward to
+// the corresponding member functions, mirroring std::endl-style usage.
+StreamWriter& operator<<(StreamWriter& os, EndRowType) {
+  os.EndRow();
+  return os;
+}
+
+StreamWriter& operator<<(StreamWriter& os, EndRowGroupType) {
+  os.EndRowGroup();
+  return os;
+}
+
+}  // namespace parquet