]> git.proxmox.com Git - ceph.git/blobdiff - ceph/src/rocksdb/utilities/agg_merge/agg_merge.cc
update ceph source to reef 18.1.2
[ceph.git] / ceph / src / rocksdb / utilities / agg_merge / agg_merge.cc
diff --git a/ceph/src/rocksdb/utilities/agg_merge/agg_merge.cc b/ceph/src/rocksdb/utilities/agg_merge/agg_merge.cc
new file mode 100644 (file)
index 0000000..a7eab1f
--- /dev/null
@@ -0,0 +1,238 @@
+//  Copyright (c) 2017-present, Facebook, Inc.  All rights reserved.
+//  This source code is licensed under both the GPLv2 (found in the
+//  COPYING file in the root directory) and Apache 2.0 License
+//  (found in the LICENSE.Apache file in the root directory).
+
+#include "utilities/agg_merge/agg_merge.h"
+
+#include <assert.h>
+
+#include <deque>
+#include <memory>
+#include <type_traits>
+#include <utility>
+#include <vector>
+
+#include "port/lang.h"
+#include "port/likely.h"
+#include "rocksdb/merge_operator.h"
+#include "rocksdb/slice.h"
+#include "rocksdb/utilities/agg_merge.h"
+#include "rocksdb/utilities/options_type.h"
+#include "util/coding.h"
+#include "utilities/merge_operators.h"
+
+namespace ROCKSDB_NAMESPACE {
+static std::unordered_map<std::string, std::unique_ptr<Aggregator>> func_map;
+const std::string kUnnamedFuncName = "";
+const std::string kErrorFuncName = "kErrorFuncName";
+
+Status AddAggregator(const std::string& function_name,
+                     std::unique_ptr<Aggregator>&& agg) {
+  if (function_name == kErrorFuncName) {
+    return Status::InvalidArgument(
+        "Cannot register function name kErrorFuncName");
+  }
+  func_map.emplace(function_name, std::move(agg));
+  return Status::OK();
+}
+
+AggMergeOperator::AggMergeOperator() {}
+
+std::string EncodeAggFuncAndPayloadNoCheck(const Slice& function_name,
+                                           const Slice& value) {
+  std::string result;
+  PutLengthPrefixedSlice(&result, function_name);
+  result += value.ToString();
+  return result;
+}
+
+Status EncodeAggFuncAndPayload(const Slice& function_name, const Slice& payload,
+                               std::string& output) {
+  if (function_name == kErrorFuncName) {
+    return Status::InvalidArgument("Cannot use error function name");
+  }
+  if (function_name != kUnnamedFuncName &&
+      func_map.find(function_name.ToString()) == func_map.end()) {
+    return Status::InvalidArgument("Function name not registered");
+  }
+  output = EncodeAggFuncAndPayloadNoCheck(function_name, payload);
+  return Status::OK();
+}
+
+bool ExtractAggFuncAndValue(const Slice& op, Slice& func, Slice& value) {
+  value = op;
+  return GetLengthPrefixedSlice(&value, &func);
+}
+
+bool ExtractList(const Slice& encoded_list, std::vector<Slice>& decoded_list) {
+  decoded_list.clear();
+  Slice list_slice = encoded_list;
+  Slice item;
+  while (GetLengthPrefixedSlice(&list_slice, &item)) {
+    decoded_list.push_back(item);
+  }
+  return list_slice.empty();
+}
+
+class AggMergeOperator::Accumulator {
+ public:
+  bool Add(const Slice& op, bool is_partial_aggregation) {
+    if (ignore_operands_) {
+      return true;
+    }
+    Slice my_func;
+    Slice my_value;
+    bool ret = ExtractAggFuncAndValue(op, my_func, my_value);
+    if (!ret) {
+      ignore_operands_ = true;
+      return true;
+    }
+
+    // Determine whether we need to do partial merge.
+    if (is_partial_aggregation && !my_func.empty()) {
+      auto f = func_map.find(my_func.ToString());
+      if (f == func_map.end() || !f->second->DoPartialAggregate()) {
+        return false;
+      }
+    }
+
+    if (!func_valid_) {
+      if (my_func != kUnnamedFuncName) {
+        func_ = my_func;
+        func_valid_ = true;
+      }
+    } else if (func_ != my_func) {
+      // User switched aggregation function. Need to aggregate the older
+      // one first.
+
+      // Previous aggreagion can't be done in partial merge
+      if (is_partial_aggregation) {
+        func_valid_ = false;
+        ignore_operands_ = true;
+        return false;
+      }
+
+      // We could consider stashing an iterator into the hash of aggregators
+      // to avoid repeated lookups when the aggregator doesn't change.
+      auto f = func_map.find(func_.ToString());
+      if (f == func_map.end() || !f->second->Aggregate(values_, scratch_)) {
+        func_valid_ = false;
+        ignore_operands_ = true;
+        return true;
+      }
+      std::swap(scratch_, aggregated_);
+      values_.clear();
+      values_.push_back(aggregated_);
+      func_ = my_func;
+    }
+    values_.push_back(my_value);
+    return true;
+  }
+
+  // Return false if aggregation fails.
+  // One possible reason
+  bool GetResult(std::string& result) {
+    if (!func_valid_) {
+      return false;
+    }
+    auto f = func_map.find(func_.ToString());
+    if (f == func_map.end()) {
+      return false;
+    }
+    if (!f->second->Aggregate(values_, scratch_)) {
+      return false;
+    }
+    result = EncodeAggFuncAndPayloadNoCheck(func_, scratch_);
+    return true;
+  }
+
+  void Clear() {
+    func_.clear();
+    values_.clear();
+    aggregated_.clear();
+    scratch_.clear();
+    ignore_operands_ = false;
+    func_valid_ = false;
+  }
+
+ private:
+  Slice func_;
+  std::vector<Slice> values_;
+  std::string aggregated_;
+  std::string scratch_;
+  bool ignore_operands_ = false;
+  bool func_valid_ = false;
+};
+
+// Creating and using a new Accumulator might invoke multiple malloc and is
+// expensive if it needs to be done when processing each merge operation.
+// AggMergeOperator's merge operators can be invoked concurrently by multiple
+// threads so we cannot simply create one Aggregator and reuse.
+// We use thread local instances instead.
+AggMergeOperator::Accumulator& AggMergeOperator::GetTLSAccumulator() {
+  static thread_local Accumulator tls_acc;
+  tls_acc.Clear();
+  return tls_acc;
+}
+
+void AggMergeOperator::PackAllMergeOperands(const MergeOperationInput& merge_in,
+                                            MergeOperationOutput& merge_out) {
+  merge_out.new_value = "";
+  PutLengthPrefixedSlice(&merge_out.new_value, kErrorFuncName);
+  if (merge_in.existing_value != nullptr) {
+    PutLengthPrefixedSlice(&merge_out.new_value, *merge_in.existing_value);
+  }
+  for (const Slice& op : merge_in.operand_list) {
+    PutLengthPrefixedSlice(&merge_out.new_value, op);
+  }
+}
+
+bool AggMergeOperator::FullMergeV2(const MergeOperationInput& merge_in,
+                                   MergeOperationOutput* merge_out) const {
+  Accumulator& agg = GetTLSAccumulator();
+  if (merge_in.existing_value != nullptr) {
+    agg.Add(*merge_in.existing_value, /*is_partial_aggregation=*/false);
+  }
+  for (const Slice& e : merge_in.operand_list) {
+    agg.Add(e, /*is_partial_aggregation=*/false);
+  }
+
+  bool succ = agg.GetResult(merge_out->new_value);
+  if (!succ) {
+    // If aggregation can't happen, pack all merge operands. In contrast to
+    // merge operator, we don't want to fail the DB. If users insert wrong
+    // format or call unregistered an aggregation function, we still hope
+    // the DB can continue functioning with other keys.
+    PackAllMergeOperands(merge_in, *merge_out);
+  }
+  agg.Clear();
+  return true;
+}
+
+bool AggMergeOperator::PartialMergeMulti(const Slice& /*key*/,
+                                         const std::deque<Slice>& operand_list,
+                                         std::string* new_value,
+                                         Logger* /*logger*/) const {
+  Accumulator& agg = GetTLSAccumulator();
+  bool do_aggregation = true;
+  for (const Slice& item : operand_list) {
+    do_aggregation = agg.Add(item, /*is_partial_aggregation=*/true);
+    if (!do_aggregation) {
+      break;
+    }
+  }
+  if (do_aggregation) {
+    do_aggregation = agg.GetResult(*new_value);
+  }
+  agg.Clear();
+  return do_aggregation;
+}
+
+std::shared_ptr<MergeOperator> GetAggMergeOperator() {
+  STATIC_AVOID_DESTRUCTION(std::shared_ptr<MergeOperator>, instance)
+  (std::make_shared<AggMergeOperator>());
+  assert(instance);
+  return instance;
+}
+}  // namespace ROCKSDB_NAMESPACE