git.proxmox.com Git - ceph.git/blob - ceph/src/arrow/cpp/src/gandiva/tests/timed_evaluate.h
import quincy 17.2.0
[ceph.git] / ceph / src / arrow / cpp / src / gandiva / tests / timed_evaluate.h
1 // Licensed to the Apache Software Foundation (ASF) under one
2 // or more contributor license agreements. See the NOTICE file
3 // distributed with this work for additional information
4 // regarding copyright ownership. The ASF licenses this file
5 // to you under the Apache License, Version 2.0 (the
6 // "License"); you may not use this file except in compliance
7 // with the License. You may obtain a copy of the License at
8 //
9 // http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing,
12 // software distributed under the License is distributed on an
13 // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14 // KIND, either express or implied. See the License for the
15 // specific language governing permissions and limitations
16 // under the License.
17
#include <memory>
#include <utility>
#include <vector>
#include "benchmark/benchmark.h"
#include "gandiva/arrow.h"
#include "gandiva/filter.h"
#include "gandiva/projector.h"
#include "gandiva/tests/generate_data.h"
25
26 #pragma once
27
// Size helpers for benchmark parameters. Note that these are binary multiples:
// THOUSAND is 1024 (not 1000) and MILLION is 1024 * 1024 (not 1000000).
#define THOUSAND (1024)
#define MILLION (THOUSAND * THOUSAND)

// Number of distinct record batches TimedEvaluate pre-builds and cycles through.
#define NUM_BATCHES 16
31
32 namespace gandiva {
33
34 template <typename C_TYPE>
35 std::vector<C_TYPE> GenerateData(int num_records, DataGenerator<C_TYPE>& data_generator) {
36 std::vector<C_TYPE> data;
37
38 for (int i = 0; i < num_records; i++) {
39 data.push_back(data_generator.GenerateData());
40 }
41
42 return data;
43 }
44
45 class BaseEvaluator {
46 public:
47 virtual ~BaseEvaluator() = default;
48
49 virtual Status Evaluate(arrow::RecordBatch& batch, arrow::MemoryPool* pool) = 0;
50 };
51
52 class ProjectEvaluator : public BaseEvaluator {
53 public:
54 explicit ProjectEvaluator(std::shared_ptr<Projector> projector)
55 : projector_(projector) {}
56
57 Status Evaluate(arrow::RecordBatch& batch, arrow::MemoryPool* pool) override {
58 arrow::ArrayVector outputs;
59 return projector_->Evaluate(batch, pool, &outputs);
60 }
61
62 private:
63 std::shared_ptr<Projector> projector_;
64 };
65
66 class FilterEvaluator : public BaseEvaluator {
67 public:
68 explicit FilterEvaluator(std::shared_ptr<Filter> filter) : filter_(filter) {}
69
70 Status Evaluate(arrow::RecordBatch& batch, arrow::MemoryPool* pool) override {
71 if (selection_ == nullptr || selection_->GetMaxSlots() < batch.num_rows()) {
72 auto status = SelectionVector::MakeInt16(batch.num_rows(), pool, &selection_);
73 if (!status.ok()) {
74 return status;
75 }
76 }
77 return filter_->Evaluate(batch, selection_);
78 }
79
80 private:
81 std::shared_ptr<Filter> filter_;
82 std::shared_ptr<SelectionVector> selection_;
83 };
84
85 template <typename TYPE, typename C_TYPE>
86 Status TimedEvaluate(SchemaPtr schema, BaseEvaluator& evaluator,
87 DataGenerator<C_TYPE>& data_generator, arrow::MemoryPool* pool,
88 int num_records, int batch_size, benchmark::State& state) {
89 int num_remaining = num_records;
90 int num_fields = schema->num_fields();
91 int num_calls = 0;
92 Status status;
93
94 // Generate batches of data
95 std::shared_ptr<arrow::RecordBatch> batches[NUM_BATCHES];
96 for (int i = 0; i < NUM_BATCHES; i++) {
97 // generate data for all columns in the schema
98 std::vector<ArrayPtr> columns;
99 for (int col = 0; col < num_fields; col++) {
100 std::vector<C_TYPE> data = GenerateData<C_TYPE>(batch_size, data_generator);
101 std::vector<bool> validity(batch_size, true);
102 ArrayPtr col_data =
103 MakeArrowArray<TYPE, C_TYPE>(schema->field(col)->type(), data, validity);
104
105 columns.push_back(col_data);
106 }
107
108 // make the record batch
109 std::shared_ptr<arrow::RecordBatch> batch =
110 arrow::RecordBatch::Make(schema, batch_size, columns);
111 batches[i] = batch;
112 }
113
114 for (auto _ : state) {
115 int num_in_batch = batch_size;
116 num_remaining = num_records;
117 while (num_remaining > 0) {
118 if (batch_size > num_remaining) {
119 num_in_batch = num_remaining;
120 }
121
122 status = evaluator.Evaluate(*(batches[num_calls % NUM_BATCHES]), pool);
123 if (!status.ok()) {
124 state.SkipWithError("Evaluation of the batch failed");
125 return status;
126 }
127
128 num_calls++;
129 num_remaining -= num_in_batch;
130 }
131 }
132
133 return Status::OK();
134 }
135
136 } // namespace gandiva