]>
Commit | Line | Data |
---|---|---|
1d09f67e TL |
1 | // Licensed to the Apache Software Foundation (ASF) under one |
2 | // or more contributor license agreements. See the NOTICE file | |
3 | // distributed with this work for additional information | |
4 | // regarding copyright ownership. The ASF licenses this file | |
5 | // to you under the Apache License, Version 2.0 (the | |
6 | // "License"); you may not use this file except in compliance | |
7 | // with the License. You may obtain a copy of the License at | |
8 | // | |
9 | // http://www.apache.org/licenses/LICENSE-2.0 | |
10 | // | |
11 | // Unless required by applicable law or agreed to in writing, | |
12 | // software distributed under the License is distributed on an | |
13 | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | |
14 | // KIND, either express or implied. See the License for the | |
15 | // specific language governing permissions and limitations | |
16 | // under the License. | |
17 | ||
18 | #pragma once | |
19 | ||
20 | #include <cmath> | |
21 | #include <utility> | |
22 | ||
23 | #include "arrow/compute/api_aggregate.h" | |
24 | #include "arrow/compute/kernels/aggregate_internal.h" | |
25 | #include "arrow/compute/kernels/codegen_internal.h" | |
26 | #include "arrow/compute/kernels/common.h" | |
27 | #include "arrow/util/align_util.h" | |
28 | #include "arrow/util/bit_block_counter.h" | |
29 | #include "arrow/util/decimal.h" | |
30 | ||
31 | namespace arrow { | |
32 | namespace compute { | |
33 | namespace internal { | |
34 | ||
35 | void AddBasicAggKernels(KernelInit init, | |
36 | const std::vector<std::shared_ptr<DataType>>& types, | |
37 | std::shared_ptr<DataType> out_ty, ScalarAggregateFunction* func, | |
38 | SimdLevel::type simd_level = SimdLevel::NONE); | |
39 | ||
40 | void AddMinMaxKernels(KernelInit init, | |
41 | const std::vector<std::shared_ptr<DataType>>& types, | |
42 | ScalarAggregateFunction* func, | |
43 | SimdLevel::type simd_level = SimdLevel::NONE); | |
44 | void AddMinMaxKernel(KernelInit init, internal::detail::GetTypeId get_id, | |
45 | ScalarAggregateFunction* func, | |
46 | SimdLevel::type simd_level = SimdLevel::NONE); | |
47 | ||
48 | // SIMD variants for kernels | |
49 | void AddSumAvx2AggKernels(ScalarAggregateFunction* func); | |
50 | void AddMeanAvx2AggKernels(ScalarAggregateFunction* func); | |
51 | void AddMinMaxAvx2AggKernels(ScalarAggregateFunction* func); | |
52 | ||
53 | void AddSumAvx512AggKernels(ScalarAggregateFunction* func); | |
54 | void AddMeanAvx512AggKernels(ScalarAggregateFunction* func); | |
55 | void AddMinMaxAvx512AggKernels(ScalarAggregateFunction* func); | |
56 | ||
57 | // ---------------------------------------------------------------------- | |
58 | // Sum implementation | |
59 | ||
60 | template <typename ArrowType, SimdLevel::type SimdLevel> | |
61 | struct SumImpl : public ScalarAggregator { | |
62 | using ThisType = SumImpl<ArrowType, SimdLevel>; | |
63 | using CType = typename TypeTraits<ArrowType>::CType; | |
64 | using SumType = typename FindAccumulatorType<ArrowType>::Type; | |
65 | using SumCType = typename TypeTraits<SumType>::CType; | |
66 | using OutputType = typename TypeTraits<SumType>::ScalarType; | |
67 | ||
68 | SumImpl(const std::shared_ptr<DataType>& out_type, | |
69 | const ScalarAggregateOptions& options_) | |
70 | : out_type(out_type), options(options_) {} | |
71 | ||
72 | Status Consume(KernelContext*, const ExecBatch& batch) override { | |
73 | if (batch[0].is_array()) { | |
74 | const auto& data = batch[0].array(); | |
75 | this->count += data->length - data->GetNullCount(); | |
76 | this->nulls_observed = this->nulls_observed || data->GetNullCount(); | |
77 | ||
78 | if (!options.skip_nulls && this->nulls_observed) { | |
79 | // Short-circuit | |
80 | return Status::OK(); | |
81 | } | |
82 | ||
83 | if (is_boolean_type<ArrowType>::value) { | |
84 | this->sum += static_cast<SumCType>(BooleanArray(data).true_count()); | |
85 | } else { | |
86 | this->sum += SumArray<CType, SumCType, SimdLevel>(*data); | |
87 | } | |
88 | } else { | |
89 | const auto& data = *batch[0].scalar(); | |
90 | this->count += data.is_valid * batch.length; | |
91 | this->nulls_observed = this->nulls_observed || !data.is_valid; | |
92 | if (data.is_valid) { | |
93 | this->sum += internal::UnboxScalar<ArrowType>::Unbox(data) * batch.length; | |
94 | } | |
95 | } | |
96 | return Status::OK(); | |
97 | } | |
98 | ||
99 | Status MergeFrom(KernelContext*, KernelState&& src) override { | |
100 | const auto& other = checked_cast<const ThisType&>(src); | |
101 | this->count += other.count; | |
102 | this->sum += other.sum; | |
103 | this->nulls_observed = this->nulls_observed || other.nulls_observed; | |
104 | return Status::OK(); | |
105 | } | |
106 | ||
107 | Status Finalize(KernelContext*, Datum* out) override { | |
108 | if ((!options.skip_nulls && this->nulls_observed) || | |
109 | (this->count < options.min_count)) { | |
110 | out->value = std::make_shared<OutputType>(out_type); | |
111 | } else { | |
112 | out->value = std::make_shared<OutputType>(this->sum, out_type); | |
113 | } | |
114 | return Status::OK(); | |
115 | } | |
116 | ||
117 | size_t count = 0; | |
118 | bool nulls_observed = false; | |
119 | SumCType sum = 0; | |
120 | std::shared_ptr<DataType> out_type; | |
121 | ScalarAggregateOptions options; | |
122 | }; | |
123 | ||
124 | template <typename ArrowType, SimdLevel::type SimdLevel> | |
125 | struct MeanImpl : public SumImpl<ArrowType, SimdLevel> { | |
126 | using SumImpl<ArrowType, SimdLevel>::SumImpl; | |
127 | ||
128 | template <typename T = ArrowType> | |
129 | enable_if_decimal<T, Status> FinalizeImpl(Datum* out) { | |
130 | using SumCType = typename SumImpl<ArrowType, SimdLevel>::SumCType; | |
131 | using OutputType = typename SumImpl<ArrowType, SimdLevel>::OutputType; | |
132 | if ((!options.skip_nulls && this->nulls_observed) || | |
133 | (this->count < options.min_count) || (this->count == 0)) { | |
134 | out->value = std::make_shared<OutputType>(this->out_type); | |
135 | } else { | |
136 | const SumCType mean = this->sum / this->count; | |
137 | out->value = std::make_shared<OutputType>(mean, this->out_type); | |
138 | } | |
139 | return Status::OK(); | |
140 | } | |
141 | template <typename T = ArrowType> | |
142 | enable_if_t<!is_decimal_type<T>::value, Status> FinalizeImpl(Datum* out) { | |
143 | if ((!options.skip_nulls && this->nulls_observed) || | |
144 | (this->count < options.min_count)) { | |
145 | out->value = std::make_shared<DoubleScalar>(); | |
146 | } else { | |
147 | const double mean = static_cast<double>(this->sum) / this->count; | |
148 | out->value = std::make_shared<DoubleScalar>(mean); | |
149 | } | |
150 | return Status::OK(); | |
151 | } | |
152 | Status Finalize(KernelContext*, Datum* out) override { return FinalizeImpl(out); } | |
153 | ||
154 | using SumImpl<ArrowType, SimdLevel>::options; | |
155 | }; | |
156 | ||
157 | template <template <typename> class KernelClass> | |
158 | struct SumLikeInit { | |
159 | std::unique_ptr<KernelState> state; | |
160 | KernelContext* ctx; | |
161 | const std::shared_ptr<DataType> type; | |
162 | const ScalarAggregateOptions& options; | |
163 | ||
164 | SumLikeInit(KernelContext* ctx, const std::shared_ptr<DataType>& type, | |
165 | const ScalarAggregateOptions& options) | |
166 | : ctx(ctx), type(type), options(options) {} | |
167 | ||
168 | Status Visit(const DataType&) { return Status::NotImplemented("No sum implemented"); } | |
169 | ||
170 | Status Visit(const HalfFloatType&) { | |
171 | return Status::NotImplemented("No sum implemented"); | |
172 | } | |
173 | ||
174 | Status Visit(const BooleanType&) { | |
175 | auto ty = TypeTraits<typename KernelClass<BooleanType>::SumType>::type_singleton(); | |
176 | state.reset(new KernelClass<BooleanType>(ty, options)); | |
177 | return Status::OK(); | |
178 | } | |
179 | ||
180 | template <typename Type> | |
181 | enable_if_number<Type, Status> Visit(const Type&) { | |
182 | auto ty = TypeTraits<typename KernelClass<Type>::SumType>::type_singleton(); | |
183 | state.reset(new KernelClass<Type>(ty, options)); | |
184 | return Status::OK(); | |
185 | } | |
186 | ||
187 | template <typename Type> | |
188 | enable_if_decimal<Type, Status> Visit(const Type&) { | |
189 | state.reset(new KernelClass<Type>(type, options)); | |
190 | return Status::OK(); | |
191 | } | |
192 | ||
193 | Result<std::unique_ptr<KernelState>> Create() { | |
194 | RETURN_NOT_OK(VisitTypeInline(*type, this)); | |
195 | return std::move(state); | |
196 | } | |
197 | }; | |
198 | ||
199 | // ---------------------------------------------------------------------- | |
200 | // MinMax implementation | |
201 | ||
202 | template <typename ArrowType, SimdLevel::type SimdLevel, typename Enable = void> | |
203 | struct MinMaxState {}; | |
204 | ||
205 | template <typename ArrowType, SimdLevel::type SimdLevel> | |
206 | struct MinMaxState<ArrowType, SimdLevel, enable_if_boolean<ArrowType>> { | |
207 | using ThisType = MinMaxState<ArrowType, SimdLevel>; | |
208 | using T = typename ArrowType::c_type; | |
209 | ||
210 | ThisType& operator+=(const ThisType& rhs) { | |
211 | this->has_nulls |= rhs.has_nulls; | |
212 | this->min = this->min && rhs.min; | |
213 | this->max = this->max || rhs.max; | |
214 | return *this; | |
215 | } | |
216 | ||
217 | void MergeOne(T value) { | |
218 | this->min = this->min && value; | |
219 | this->max = this->max || value; | |
220 | } | |
221 | ||
222 | T min = true; | |
223 | T max = false; | |
224 | bool has_nulls = false; | |
225 | }; | |
226 | ||
227 | template <typename ArrowType, SimdLevel::type SimdLevel> | |
228 | struct MinMaxState<ArrowType, SimdLevel, enable_if_integer<ArrowType>> { | |
229 | using ThisType = MinMaxState<ArrowType, SimdLevel>; | |
230 | using T = typename ArrowType::c_type; | |
231 | using ScalarType = typename TypeTraits<ArrowType>::ScalarType; | |
232 | ||
233 | ThisType& operator+=(const ThisType& rhs) { | |
234 | this->has_nulls |= rhs.has_nulls; | |
235 | this->min = std::min(this->min, rhs.min); | |
236 | this->max = std::max(this->max, rhs.max); | |
237 | return *this; | |
238 | } | |
239 | ||
240 | void MergeOne(T value) { | |
241 | this->min = std::min(this->min, value); | |
242 | this->max = std::max(this->max, value); | |
243 | } | |
244 | ||
245 | T min = std::numeric_limits<T>::max(); | |
246 | T max = std::numeric_limits<T>::min(); | |
247 | bool has_nulls = false; | |
248 | }; | |
249 | ||
250 | template <typename ArrowType, SimdLevel::type SimdLevel> | |
251 | struct MinMaxState<ArrowType, SimdLevel, enable_if_floating_point<ArrowType>> { | |
252 | using ThisType = MinMaxState<ArrowType, SimdLevel>; | |
253 | using T = typename ArrowType::c_type; | |
254 | using ScalarType = typename TypeTraits<ArrowType>::ScalarType; | |
255 | ||
256 | ThisType& operator+=(const ThisType& rhs) { | |
257 | this->has_nulls |= rhs.has_nulls; | |
258 | this->min = std::fmin(this->min, rhs.min); | |
259 | this->max = std::fmax(this->max, rhs.max); | |
260 | return *this; | |
261 | } | |
262 | ||
263 | void MergeOne(T value) { | |
264 | this->min = std::fmin(this->min, value); | |
265 | this->max = std::fmax(this->max, value); | |
266 | } | |
267 | ||
268 | T min = std::numeric_limits<T>::infinity(); | |
269 | T max = -std::numeric_limits<T>::infinity(); | |
270 | bool has_nulls = false; | |
271 | }; | |
272 | ||
273 | template <typename ArrowType, SimdLevel::type SimdLevel> | |
274 | struct MinMaxState<ArrowType, SimdLevel, enable_if_decimal<ArrowType>> { | |
275 | using ThisType = MinMaxState<ArrowType, SimdLevel>; | |
276 | using T = typename TypeTraits<ArrowType>::CType; | |
277 | using ScalarType = typename TypeTraits<ArrowType>::ScalarType; | |
278 | ||
279 | MinMaxState() : min(T::GetMaxSentinel()), max(T::GetMinSentinel()) {} | |
280 | ||
281 | ThisType& operator+=(const ThisType& rhs) { | |
282 | this->has_nulls |= rhs.has_nulls; | |
283 | this->min = std::min(this->min, rhs.min); | |
284 | this->max = std::max(this->max, rhs.max); | |
285 | return *this; | |
286 | } | |
287 | ||
288 | void MergeOne(util::string_view value) { | |
289 | MergeOne(T(reinterpret_cast<const uint8_t*>(value.data()))); | |
290 | } | |
291 | ||
292 | void MergeOne(const T value) { | |
293 | this->min = std::min(this->min, value); | |
294 | this->max = std::max(this->max, value); | |
295 | } | |
296 | ||
297 | T min; | |
298 | T max; | |
299 | bool has_nulls = false; | |
300 | }; | |
301 | ||
302 | template <typename ArrowType, SimdLevel::type SimdLevel> | |
303 | struct MinMaxState<ArrowType, SimdLevel, | |
304 | enable_if_t<is_base_binary_type<ArrowType>::value || | |
305 | std::is_same<ArrowType, FixedSizeBinaryType>::value>> { | |
306 | using ThisType = MinMaxState<ArrowType, SimdLevel>; | |
307 | using ScalarType = typename TypeTraits<ArrowType>::ScalarType; | |
308 | ||
309 | ThisType& operator+=(const ThisType& rhs) { | |
310 | if (!this->seen && rhs.seen) { | |
311 | this->min = rhs.min; | |
312 | this->max = rhs.max; | |
313 | } else if (this->seen && rhs.seen) { | |
314 | if (this->min > rhs.min) { | |
315 | this->min = rhs.min; | |
316 | } | |
317 | if (this->max < rhs.max) { | |
318 | this->max = rhs.max; | |
319 | } | |
320 | } | |
321 | this->has_nulls |= rhs.has_nulls; | |
322 | this->seen |= rhs.seen; | |
323 | return *this; | |
324 | } | |
325 | ||
326 | void MergeOne(util::string_view value) { | |
327 | if (!seen) { | |
328 | this->min = std::string(value); | |
329 | this->max = std::string(value); | |
330 | } else { | |
331 | if (value < util::string_view(this->min)) { | |
332 | this->min = std::string(value); | |
333 | } else if (value > util::string_view(this->max)) { | |
334 | this->max = std::string(value); | |
335 | } | |
336 | } | |
337 | this->seen = true; | |
338 | } | |
339 | ||
340 | std::string min; | |
341 | std::string max; | |
342 | bool has_nulls = false; | |
343 | bool seen = false; | |
344 | }; | |
345 | ||
346 | template <typename ArrowType, SimdLevel::type SimdLevel> | |
347 | struct MinMaxImpl : public ScalarAggregator { | |
348 | using ArrayType = typename TypeTraits<ArrowType>::ArrayType; | |
349 | using ThisType = MinMaxImpl<ArrowType, SimdLevel>; | |
350 | using StateType = MinMaxState<ArrowType, SimdLevel>; | |
351 | ||
352 | MinMaxImpl(std::shared_ptr<DataType> out_type, ScalarAggregateOptions options) | |
353 | : out_type(std::move(out_type)), options(std::move(options)), count(0) { | |
354 | this->options.min_count = std::max<uint32_t>(1, this->options.min_count); | |
355 | } | |
356 | ||
357 | Status Consume(KernelContext*, const ExecBatch& batch) override { | |
358 | if (batch[0].is_array()) { | |
359 | return ConsumeArray(ArrayType(batch[0].array())); | |
360 | } | |
361 | return ConsumeScalar(*batch[0].scalar()); | |
362 | } | |
363 | ||
364 | Status ConsumeScalar(const Scalar& scalar) { | |
365 | StateType local; | |
366 | local.has_nulls = !scalar.is_valid; | |
367 | this->count += scalar.is_valid; | |
368 | ||
369 | if (local.has_nulls && !options.skip_nulls) { | |
370 | this->state = local; | |
371 | return Status::OK(); | |
372 | } | |
373 | ||
374 | local.MergeOne(internal::UnboxScalar<ArrowType>::Unbox(scalar)); | |
375 | this->state = local; | |
376 | return Status::OK(); | |
377 | } | |
378 | ||
379 | Status ConsumeArray(const ArrayType& arr) { | |
380 | StateType local; | |
381 | ||
382 | const auto null_count = arr.null_count(); | |
383 | local.has_nulls = null_count > 0; | |
384 | this->count += arr.length() - null_count; | |
385 | ||
386 | if (local.has_nulls && !options.skip_nulls) { | |
387 | this->state = local; | |
388 | return Status::OK(); | |
389 | } | |
390 | ||
391 | if (local.has_nulls) { | |
392 | local += ConsumeWithNulls(arr); | |
393 | } else { // All true values | |
394 | for (int64_t i = 0; i < arr.length(); i++) { | |
395 | local.MergeOne(arr.GetView(i)); | |
396 | } | |
397 | } | |
398 | this->state = local; | |
399 | return Status::OK(); | |
400 | } | |
401 | ||
402 | Status MergeFrom(KernelContext*, KernelState&& src) override { | |
403 | const auto& other = checked_cast<const ThisType&>(src); | |
404 | this->state += other.state; | |
405 | this->count += other.count; | |
406 | return Status::OK(); | |
407 | } | |
408 | ||
409 | Status Finalize(KernelContext*, Datum* out) override { | |
410 | const auto& struct_type = checked_cast<const StructType&>(*out_type); | |
411 | const auto& child_type = struct_type.field(0)->type(); | |
412 | ||
413 | std::vector<std::shared_ptr<Scalar>> values; | |
414 | // Physical type != result type | |
415 | if ((state.has_nulls && !options.skip_nulls) || (this->count < options.min_count)) { | |
416 | // (null, null) | |
417 | auto null_scalar = MakeNullScalar(child_type); | |
418 | values = {null_scalar, null_scalar}; | |
419 | } else { | |
420 | ARROW_ASSIGN_OR_RAISE(auto min_scalar, | |
421 | MakeScalar(child_type, std::move(state.min))); | |
422 | ARROW_ASSIGN_OR_RAISE(auto max_scalar, | |
423 | MakeScalar(child_type, std::move(state.max))); | |
424 | values = {std::move(min_scalar), std::move(max_scalar)}; | |
425 | } | |
426 | out->value = std::make_shared<StructScalar>(std::move(values), this->out_type); | |
427 | return Status::OK(); | |
428 | } | |
429 | ||
430 | std::shared_ptr<DataType> out_type; | |
431 | ScalarAggregateOptions options; | |
432 | int64_t count; | |
433 | MinMaxState<ArrowType, SimdLevel> state; | |
434 | ||
435 | private: | |
436 | StateType ConsumeWithNulls(const ArrayType& arr) const { | |
437 | StateType local; | |
438 | const int64_t length = arr.length(); | |
439 | int64_t offset = arr.offset(); | |
440 | const uint8_t* bitmap = arr.null_bitmap_data(); | |
441 | int64_t idx = 0; | |
442 | ||
443 | const auto p = arrow::internal::BitmapWordAlign<1>(bitmap, offset, length); | |
444 | // First handle the leading bits | |
445 | const int64_t leading_bits = p.leading_bits; | |
446 | while (idx < leading_bits) { | |
447 | if (BitUtil::GetBit(bitmap, offset)) { | |
448 | local.MergeOne(arr.GetView(idx)); | |
449 | } | |
450 | idx++; | |
451 | offset++; | |
452 | } | |
453 | ||
454 | // The aligned parts scanned with BitBlockCounter | |
455 | arrow::internal::BitBlockCounter data_counter(bitmap, offset, length - leading_bits); | |
456 | auto current_block = data_counter.NextWord(); | |
457 | while (idx < length) { | |
458 | if (current_block.AllSet()) { // All true values | |
459 | int run_length = 0; | |
460 | // Scan forward until a block that has some false values (or the end) | |
461 | while (current_block.length > 0 && current_block.AllSet()) { | |
462 | run_length += current_block.length; | |
463 | current_block = data_counter.NextWord(); | |
464 | } | |
465 | for (int64_t i = 0; i < run_length; i++) { | |
466 | local.MergeOne(arr.GetView(idx + i)); | |
467 | } | |
468 | idx += run_length; | |
469 | offset += run_length; | |
470 | // The current_block already computed, advance to next loop | |
471 | continue; | |
472 | } else if (!current_block.NoneSet()) { // Some values are null | |
473 | BitmapReader reader(arr.null_bitmap_data(), offset, current_block.length); | |
474 | for (int64_t i = 0; i < current_block.length; i++) { | |
475 | if (reader.IsSet()) { | |
476 | local.MergeOne(arr.GetView(idx + i)); | |
477 | } | |
478 | reader.Next(); | |
479 | } | |
480 | ||
481 | idx += current_block.length; | |
482 | offset += current_block.length; | |
483 | } else { // All null values | |
484 | idx += current_block.length; | |
485 | offset += current_block.length; | |
486 | } | |
487 | current_block = data_counter.NextWord(); | |
488 | } | |
489 | ||
490 | return local; | |
491 | } | |
492 | }; | |
493 | ||
494 | template <SimdLevel::type SimdLevel> | |
495 | struct BooleanMinMaxImpl : public MinMaxImpl<BooleanType, SimdLevel> { | |
496 | using StateType = MinMaxState<BooleanType, SimdLevel>; | |
497 | using ArrayType = typename TypeTraits<BooleanType>::ArrayType; | |
498 | using MinMaxImpl<BooleanType, SimdLevel>::MinMaxImpl; | |
499 | using MinMaxImpl<BooleanType, SimdLevel>::options; | |
500 | ||
501 | Status Consume(KernelContext*, const ExecBatch& batch) override { | |
502 | if (ARROW_PREDICT_FALSE(batch[0].is_scalar())) { | |
503 | return ConsumeScalar(checked_cast<const BooleanScalar&>(*batch[0].scalar())); | |
504 | } | |
505 | StateType local; | |
506 | ArrayType arr(batch[0].array()); | |
507 | ||
508 | const auto arr_length = arr.length(); | |
509 | const auto null_count = arr.null_count(); | |
510 | const auto valid_count = arr_length - null_count; | |
511 | ||
512 | local.has_nulls = null_count > 0; | |
513 | this->count += valid_count; | |
514 | if (local.has_nulls && !options.skip_nulls) { | |
515 | this->state = local; | |
516 | return Status::OK(); | |
517 | } | |
518 | ||
519 | const auto true_count = arr.true_count(); | |
520 | const auto false_count = valid_count - true_count; | |
521 | local.max = true_count > 0; | |
522 | local.min = false_count == 0; | |
523 | ||
524 | this->state = local; | |
525 | return Status::OK(); | |
526 | } | |
527 | ||
528 | Status ConsumeScalar(const BooleanScalar& scalar) { | |
529 | StateType local; | |
530 | ||
531 | local.has_nulls = !scalar.is_valid; | |
532 | this->count += scalar.is_valid; | |
533 | if (local.has_nulls && !options.skip_nulls) { | |
534 | this->state = local; | |
535 | return Status::OK(); | |
536 | } | |
537 | ||
538 | const int true_count = scalar.is_valid && scalar.value; | |
539 | const int false_count = scalar.is_valid && !scalar.value; | |
540 | local.max = true_count > 0; | |
541 | local.min = false_count == 0; | |
542 | ||
543 | this->state = local; | |
544 | return Status::OK(); | |
545 | } | |
546 | }; | |
547 | ||
548 | struct NullMinMaxImpl : public ScalarAggregator { | |
549 | Status Consume(KernelContext*, const ExecBatch& batch) override { return Status::OK(); } | |
550 | ||
551 | Status MergeFrom(KernelContext*, KernelState&& src) override { return Status::OK(); } | |
552 | ||
553 | Status Finalize(KernelContext*, Datum* out) override { | |
554 | std::vector<std::shared_ptr<Scalar>> values{std::make_shared<NullScalar>(), | |
555 | std::make_shared<NullScalar>()}; | |
556 | out->value = std::make_shared<StructScalar>( | |
557 | std::move(values), struct_({field("min", null()), field("max", null())})); | |
558 | return Status::OK(); | |
559 | } | |
560 | }; | |
561 | ||
562 | template <SimdLevel::type SimdLevel> | |
563 | struct MinMaxInitState { | |
564 | std::unique_ptr<KernelState> state; | |
565 | KernelContext* ctx; | |
566 | const DataType& in_type; | |
567 | const std::shared_ptr<DataType>& out_type; | |
568 | const ScalarAggregateOptions& options; | |
569 | ||
570 | MinMaxInitState(KernelContext* ctx, const DataType& in_type, | |
571 | const std::shared_ptr<DataType>& out_type, | |
572 | const ScalarAggregateOptions& options) | |
573 | : ctx(ctx), in_type(in_type), out_type(out_type), options(options) {} | |
574 | ||
575 | Status Visit(const DataType& ty) { | |
576 | return Status::NotImplemented("No min/max implemented for ", ty); | |
577 | } | |
578 | ||
579 | Status Visit(const HalfFloatType& ty) { | |
580 | return Status::NotImplemented("No min/max implemented for ", ty); | |
581 | } | |
582 | ||
583 | Status Visit(const NullType&) { | |
584 | state.reset(new NullMinMaxImpl()); | |
585 | return Status::OK(); | |
586 | } | |
587 | ||
588 | Status Visit(const BooleanType&) { | |
589 | state.reset(new BooleanMinMaxImpl<SimdLevel>(out_type, options)); | |
590 | return Status::OK(); | |
591 | } | |
592 | ||
593 | template <typename Type> | |
594 | enable_if_physical_integer<Type, Status> Visit(const Type&) { | |
595 | using PhysicalType = typename Type::PhysicalType; | |
596 | state.reset(new MinMaxImpl<PhysicalType, SimdLevel>(out_type, options)); | |
597 | return Status::OK(); | |
598 | } | |
599 | ||
600 | template <typename Type> | |
601 | enable_if_floating_point<Type, Status> Visit(const Type&) { | |
602 | state.reset(new MinMaxImpl<Type, SimdLevel>(out_type, options)); | |
603 | return Status::OK(); | |
604 | } | |
605 | ||
606 | template <typename Type> | |
607 | enable_if_base_binary<Type, Status> Visit(const Type&) { | |
608 | state.reset(new MinMaxImpl<Type, SimdLevel>(out_type, options)); | |
609 | return Status::OK(); | |
610 | } | |
611 | ||
612 | template <typename Type> | |
613 | enable_if_fixed_size_binary<Type, Status> Visit(const Type&) { | |
614 | state.reset(new MinMaxImpl<Type, SimdLevel>(out_type, options)); | |
615 | return Status::OK(); | |
616 | } | |
617 | ||
618 | Result<std::unique_ptr<KernelState>> Create() { | |
619 | RETURN_NOT_OK(VisitTypeInline(in_type, this)); | |
620 | return std::move(state); | |
621 | } | |
622 | }; | |
623 | ||
624 | } // namespace internal | |
625 | } // namespace compute | |
626 | } // namespace arrow |