]>
Commit | Line | Data |
---|---|---|
1d09f67e TL |
1 | // Licensed to the Apache Software Foundation (ASF) under one |
2 | // or more contributor license agreements. See the NOTICE file | |
3 | // distributed with this work for additional information | |
4 | // regarding copyright ownership. The ASF licenses this file | |
5 | // to you under the Apache License, Version 2.0 (the | |
6 | // "License"); you may not use this file except in compliance | |
7 | // with the License. You may obtain a copy of the License at | |
8 | // | |
9 | // http://www.apache.org/licenses/LICENSE-2.0 | |
10 | // | |
11 | // Unless required by applicable law or agreed to in writing, | |
12 | // software distributed under the License is distributed on an | |
13 | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | |
14 | // KIND, either express or implied. See the License for the | |
15 | // specific language governing permissions and limitations | |
16 | // under the License. | |
17 | ||
18 | #pragma once | |
19 | ||
20 | #include <cstdint> | |
21 | #include <memory> | |
22 | #include <string> | |
23 | #include <vector> | |
24 | ||
25 | #include "arrow/record_batch.h" | |
26 | #include "arrow/testing/gtest_util.h" | |
27 | #include "arrow/testing/visibility.h" | |
28 | #include "arrow/type.h" | |
29 | ||
30 | namespace arrow { | |
31 | ||
32 | class ARROW_TESTING_EXPORT ConstantArrayGenerator { | |
33 | public: | |
34 | /// \brief Generates a constant BooleanArray | |
35 | /// | |
36 | /// \param[in] size the size of the array to generate | |
37 | /// \param[in] value to repeat | |
38 | /// | |
39 | /// \return a generated Array | |
40 | static std::shared_ptr<arrow::Array> Boolean(int64_t size, bool value = false); | |
41 | ||
42 | /// \brief Generates a constant UInt8Array | |
43 | /// | |
44 | /// \param[in] size the size of the array to generate | |
45 | /// \param[in] value to repeat | |
46 | /// | |
47 | /// \return a generated Array | |
48 | static std::shared_ptr<arrow::Array> UInt8(int64_t size, uint8_t value = 0); | |
49 | ||
50 | /// \brief Generates a constant Int8Array | |
51 | /// | |
52 | /// \param[in] size the size of the array to generate | |
53 | /// \param[in] value to repeat | |
54 | /// | |
55 | /// \return a generated Array | |
56 | static std::shared_ptr<arrow::Array> Int8(int64_t size, int8_t value = 0); | |
57 | ||
58 | /// \brief Generates a constant UInt16Array | |
59 | /// | |
60 | /// \param[in] size the size of the array to generate | |
61 | /// \param[in] value to repeat | |
62 | /// | |
63 | /// \return a generated Array | |
64 | static std::shared_ptr<arrow::Array> UInt16(int64_t size, uint16_t value = 0); | |
65 | ||
66 | /// \brief Generates a constant UInt16Array | |
67 | /// | |
68 | /// \param[in] size the size of the array to generate | |
69 | /// \param[in] value to repeat | |
70 | /// | |
71 | /// \return a generated Array | |
72 | static std::shared_ptr<arrow::Array> Int16(int64_t size, int16_t value = 0); | |
73 | ||
74 | /// \brief Generates a constant UInt32Array | |
75 | /// | |
76 | /// \param[in] size the size of the array to generate | |
77 | /// \param[in] value to repeat | |
78 | /// | |
79 | /// \return a generated Array | |
80 | static std::shared_ptr<arrow::Array> UInt32(int64_t size, uint32_t value = 0); | |
81 | ||
82 | /// \brief Generates a constant UInt32Array | |
83 | /// | |
84 | /// \param[in] size the size of the array to generate | |
85 | /// \param[in] value to repeat | |
86 | /// | |
87 | /// \return a generated Array | |
88 | static std::shared_ptr<arrow::Array> Int32(int64_t size, int32_t value = 0); | |
89 | ||
90 | /// \brief Generates a constant UInt64Array | |
91 | /// | |
92 | /// \param[in] size the size of the array to generate | |
93 | /// \param[in] value to repeat | |
94 | /// | |
95 | /// \return a generated Array | |
96 | static std::shared_ptr<arrow::Array> UInt64(int64_t size, uint64_t value = 0); | |
97 | ||
98 | /// \brief Generates a constant UInt64Array | |
99 | /// | |
100 | /// \param[in] size the size of the array to generate | |
101 | /// \param[in] value to repeat | |
102 | /// | |
103 | /// \return a generated Array | |
104 | static std::shared_ptr<arrow::Array> Int64(int64_t size, int64_t value = 0); | |
105 | ||
106 | /// \brief Generates a constant Float32Array | |
107 | /// | |
108 | /// \param[in] size the size of the array to generate | |
109 | /// \param[in] value to repeat | |
110 | /// | |
111 | /// \return a generated Array | |
112 | static std::shared_ptr<arrow::Array> Float32(int64_t size, float value = 0); | |
113 | ||
114 | /// \brief Generates a constant Float64Array | |
115 | /// | |
116 | /// \param[in] size the size of the array to generate | |
117 | /// \param[in] value to repeat | |
118 | /// | |
119 | /// \return a generated Array | |
120 | static std::shared_ptr<arrow::Array> Float64(int64_t size, double value = 0); | |
121 | ||
122 | /// \brief Generates a constant StringArray | |
123 | /// | |
124 | /// \param[in] size the size of the array to generate | |
125 | /// \param[in] value to repeat | |
126 | /// | |
127 | /// \return a generated Array | |
128 | static std::shared_ptr<arrow::Array> String(int64_t size, std::string value = ""); | |
129 | ||
130 | template <typename ArrowType, typename CType = typename ArrowType::c_type> | |
131 | static std::shared_ptr<arrow::Array> Numeric(int64_t size, CType value = 0) { | |
132 | switch (ArrowType::type_id) { | |
133 | case Type::BOOL: | |
134 | return Boolean(size, static_cast<bool>(value)); | |
135 | case Type::UINT8: | |
136 | return UInt8(size, static_cast<uint8_t>(value)); | |
137 | case Type::INT8: | |
138 | return Int8(size, static_cast<int8_t>(value)); | |
139 | case Type::UINT16: | |
140 | return UInt16(size, static_cast<uint16_t>(value)); | |
141 | case Type::INT16: | |
142 | return Int16(size, static_cast<int16_t>(value)); | |
143 | case Type::UINT32: | |
144 | return UInt32(size, static_cast<uint32_t>(value)); | |
145 | case Type::INT32: | |
146 | return Int32(size, static_cast<int32_t>(value)); | |
147 | case Type::UINT64: | |
148 | return UInt64(size, static_cast<uint64_t>(value)); | |
149 | case Type::INT64: | |
150 | return Int64(size, static_cast<int64_t>(value)); | |
151 | case Type::FLOAT: | |
152 | return Float32(size, static_cast<float>(value)); | |
153 | case Type::DOUBLE: | |
154 | return Float64(size, static_cast<double>(value)); | |
155 | default: | |
156 | return nullptr; | |
157 | } | |
158 | } | |
159 | ||
160 | /// \brief Generates a constant Array of zeroes | |
161 | /// | |
162 | /// \param[in] size the size of the array to generate | |
163 | /// \param[in] type the type of the Array | |
164 | /// | |
165 | /// \return a generated Array | |
166 | static std::shared_ptr<arrow::Array> Zeroes(int64_t size, | |
167 | const std::shared_ptr<DataType>& type) { | |
168 | switch (type->id()) { | |
169 | case Type::NA: | |
170 | return std::make_shared<NullArray>(size); | |
171 | case Type::BOOL: | |
172 | return Boolean(size); | |
173 | case Type::UINT8: | |
174 | return UInt8(size); | |
175 | case Type::INT8: | |
176 | return Int8(size); | |
177 | case Type::UINT16: | |
178 | return UInt16(size); | |
179 | case Type::INT16: | |
180 | return Int16(size); | |
181 | case Type::UINT32: | |
182 | return UInt32(size); | |
183 | case Type::INT32: | |
184 | return Int32(size); | |
185 | case Type::UINT64: | |
186 | return UInt64(size); | |
187 | case Type::INT64: | |
188 | return Int64(size); | |
189 | case Type::TIME64: | |
190 | case Type::DATE64: | |
191 | case Type::TIMESTAMP: { | |
192 | EXPECT_OK_AND_ASSIGN(auto viewed, Int64(size)->View(type)); | |
193 | return viewed; | |
194 | } | |
195 | case Type::INTERVAL_DAY_TIME: | |
196 | case Type::INTERVAL_MONTHS: | |
197 | case Type::TIME32: | |
198 | case Type::DATE32: { | |
199 | EXPECT_OK_AND_ASSIGN(auto viewed, Int32(size)->View(type)); | |
200 | return viewed; | |
201 | } | |
202 | case Type::FLOAT: | |
203 | return Float32(size); | |
204 | case Type::DOUBLE: | |
205 | return Float64(size); | |
206 | case Type::STRING: | |
207 | return String(size); | |
208 | default: | |
209 | return nullptr; | |
210 | } | |
211 | } | |
212 | ||
213 | /// \brief Generates a RecordBatch of zeroes | |
214 | /// | |
215 | /// \param[in] size the size of the array to generate | |
216 | /// \param[in] schema to conform to | |
217 | /// | |
218 | /// This function is handy to return of RecordBatch of a desired shape. | |
219 | /// | |
220 | /// \return a generated RecordBatch | |
221 | static std::shared_ptr<arrow::RecordBatch> Zeroes( | |
222 | int64_t size, const std::shared_ptr<Schema>& schema) { | |
223 | std::vector<std::shared_ptr<Array>> arrays; | |
224 | ||
225 | for (const auto& field : schema->fields()) { | |
226 | arrays.emplace_back(Zeroes(size, field->type())); | |
227 | } | |
228 | ||
229 | return RecordBatch::Make(schema, size, arrays); | |
230 | } | |
231 | ||
232 | /// \brief Generates a RecordBatchReader by repeating a RecordBatch | |
233 | /// | |
234 | /// \param[in] n_batch the number of times it repeats batch | |
235 | /// \param[in] batch the RecordBatch to repeat | |
236 | /// | |
237 | /// \return a generated RecordBatchReader | |
238 | static std::shared_ptr<arrow::RecordBatchReader> Repeat( | |
239 | int64_t n_batch, const std::shared_ptr<RecordBatch> batch) { | |
240 | std::vector<std::shared_ptr<RecordBatch>> batches(static_cast<size_t>(n_batch), | |
241 | batch); | |
242 | return *RecordBatchReader::Make(batches); | |
243 | } | |
244 | ||
245 | /// \brief Generates a RecordBatchReader of zeroes batches | |
246 | /// | |
247 | /// \param[in] n_batch the number of RecordBatch | |
248 | /// \param[in] batch_size the size of each RecordBatch | |
249 | /// \param[in] schema to conform to | |
250 | /// | |
251 | /// \return a generated RecordBatchReader | |
252 | static std::shared_ptr<arrow::RecordBatchReader> Zeroes( | |
253 | int64_t n_batch, int64_t batch_size, const std::shared_ptr<Schema>& schema) { | |
254 | return Repeat(n_batch, Zeroes(batch_size, schema)); | |
255 | } | |
256 | }; | |
257 | ||
258 | ARROW_TESTING_EXPORT | |
259 | Result<std::shared_ptr<Array>> ScalarVectorToArray(const ScalarVector& scalars); | |
260 | ||
261 | } // namespace arrow |