]> git.proxmox.com Git - ceph.git/blob - ceph/src/arrow/cpp/src/arrow/compute/kernels/aggregate_basic_internal.h
import quincy 17.2.0
[ceph.git] / ceph / src / arrow / cpp / src / arrow / compute / kernels / aggregate_basic_internal.h
1 // Licensed to the Apache Software Foundation (ASF) under one
2 // or more contributor license agreements. See the NOTICE file
3 // distributed with this work for additional information
4 // regarding copyright ownership. The ASF licenses this file
5 // to you under the Apache License, Version 2.0 (the
6 // "License"); you may not use this file except in compliance
7 // with the License. You may obtain a copy of the License at
8 //
9 // http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing,
12 // software distributed under the License is distributed on an
13 // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14 // KIND, either express or implied. See the License for the
15 // specific language governing permissions and limitations
16 // under the License.
17
18 #pragma once
19
20 #include <cmath>
21 #include <utility>
22
23 #include "arrow/compute/api_aggregate.h"
24 #include "arrow/compute/kernels/aggregate_internal.h"
25 #include "arrow/compute/kernels/codegen_internal.h"
26 #include "arrow/compute/kernels/common.h"
27 #include "arrow/util/align_util.h"
28 #include "arrow/util/bit_block_counter.h"
29 #include "arrow/util/decimal.h"
30
31 namespace arrow {
32 namespace compute {
33 namespace internal {
34
35 void AddBasicAggKernels(KernelInit init,
36 const std::vector<std::shared_ptr<DataType>>& types,
37 std::shared_ptr<DataType> out_ty, ScalarAggregateFunction* func,
38 SimdLevel::type simd_level = SimdLevel::NONE);
39
40 void AddMinMaxKernels(KernelInit init,
41 const std::vector<std::shared_ptr<DataType>>& types,
42 ScalarAggregateFunction* func,
43 SimdLevel::type simd_level = SimdLevel::NONE);
44 void AddMinMaxKernel(KernelInit init, internal::detail::GetTypeId get_id,
45 ScalarAggregateFunction* func,
46 SimdLevel::type simd_level = SimdLevel::NONE);
47
48 // SIMD variants for kernels
49 void AddSumAvx2AggKernels(ScalarAggregateFunction* func);
50 void AddMeanAvx2AggKernels(ScalarAggregateFunction* func);
51 void AddMinMaxAvx2AggKernels(ScalarAggregateFunction* func);
52
53 void AddSumAvx512AggKernels(ScalarAggregateFunction* func);
54 void AddMeanAvx512AggKernels(ScalarAggregateFunction* func);
55 void AddMinMaxAvx512AggKernels(ScalarAggregateFunction* func);
56
57 // ----------------------------------------------------------------------
58 // Sum implementation
59
60 template <typename ArrowType, SimdLevel::type SimdLevel>
61 struct SumImpl : public ScalarAggregator {
62 using ThisType = SumImpl<ArrowType, SimdLevel>;
63 using CType = typename TypeTraits<ArrowType>::CType;
64 using SumType = typename FindAccumulatorType<ArrowType>::Type;
65 using SumCType = typename TypeTraits<SumType>::CType;
66 using OutputType = typename TypeTraits<SumType>::ScalarType;
67
68 SumImpl(const std::shared_ptr<DataType>& out_type,
69 const ScalarAggregateOptions& options_)
70 : out_type(out_type), options(options_) {}
71
72 Status Consume(KernelContext*, const ExecBatch& batch) override {
73 if (batch[0].is_array()) {
74 const auto& data = batch[0].array();
75 this->count += data->length - data->GetNullCount();
76 this->nulls_observed = this->nulls_observed || data->GetNullCount();
77
78 if (!options.skip_nulls && this->nulls_observed) {
79 // Short-circuit
80 return Status::OK();
81 }
82
83 if (is_boolean_type<ArrowType>::value) {
84 this->sum += static_cast<SumCType>(BooleanArray(data).true_count());
85 } else {
86 this->sum += SumArray<CType, SumCType, SimdLevel>(*data);
87 }
88 } else {
89 const auto& data = *batch[0].scalar();
90 this->count += data.is_valid * batch.length;
91 this->nulls_observed = this->nulls_observed || !data.is_valid;
92 if (data.is_valid) {
93 this->sum += internal::UnboxScalar<ArrowType>::Unbox(data) * batch.length;
94 }
95 }
96 return Status::OK();
97 }
98
99 Status MergeFrom(KernelContext*, KernelState&& src) override {
100 const auto& other = checked_cast<const ThisType&>(src);
101 this->count += other.count;
102 this->sum += other.sum;
103 this->nulls_observed = this->nulls_observed || other.nulls_observed;
104 return Status::OK();
105 }
106
107 Status Finalize(KernelContext*, Datum* out) override {
108 if ((!options.skip_nulls && this->nulls_observed) ||
109 (this->count < options.min_count)) {
110 out->value = std::make_shared<OutputType>(out_type);
111 } else {
112 out->value = std::make_shared<OutputType>(this->sum, out_type);
113 }
114 return Status::OK();
115 }
116
117 size_t count = 0;
118 bool nulls_observed = false;
119 SumCType sum = 0;
120 std::shared_ptr<DataType> out_type;
121 ScalarAggregateOptions options;
122 };
123
124 template <typename ArrowType, SimdLevel::type SimdLevel>
125 struct MeanImpl : public SumImpl<ArrowType, SimdLevel> {
126 using SumImpl<ArrowType, SimdLevel>::SumImpl;
127
128 template <typename T = ArrowType>
129 enable_if_decimal<T, Status> FinalizeImpl(Datum* out) {
130 using SumCType = typename SumImpl<ArrowType, SimdLevel>::SumCType;
131 using OutputType = typename SumImpl<ArrowType, SimdLevel>::OutputType;
132 if ((!options.skip_nulls && this->nulls_observed) ||
133 (this->count < options.min_count) || (this->count == 0)) {
134 out->value = std::make_shared<OutputType>(this->out_type);
135 } else {
136 const SumCType mean = this->sum / this->count;
137 out->value = std::make_shared<OutputType>(mean, this->out_type);
138 }
139 return Status::OK();
140 }
141 template <typename T = ArrowType>
142 enable_if_t<!is_decimal_type<T>::value, Status> FinalizeImpl(Datum* out) {
143 if ((!options.skip_nulls && this->nulls_observed) ||
144 (this->count < options.min_count)) {
145 out->value = std::make_shared<DoubleScalar>();
146 } else {
147 const double mean = static_cast<double>(this->sum) / this->count;
148 out->value = std::make_shared<DoubleScalar>(mean);
149 }
150 return Status::OK();
151 }
152 Status Finalize(KernelContext*, Datum* out) override { return FinalizeImpl(out); }
153
154 using SumImpl<ArrowType, SimdLevel>::options;
155 };
156
157 template <template <typename> class KernelClass>
158 struct SumLikeInit {
159 std::unique_ptr<KernelState> state;
160 KernelContext* ctx;
161 const std::shared_ptr<DataType> type;
162 const ScalarAggregateOptions& options;
163
164 SumLikeInit(KernelContext* ctx, const std::shared_ptr<DataType>& type,
165 const ScalarAggregateOptions& options)
166 : ctx(ctx), type(type), options(options) {}
167
168 Status Visit(const DataType&) { return Status::NotImplemented("No sum implemented"); }
169
170 Status Visit(const HalfFloatType&) {
171 return Status::NotImplemented("No sum implemented");
172 }
173
174 Status Visit(const BooleanType&) {
175 auto ty = TypeTraits<typename KernelClass<BooleanType>::SumType>::type_singleton();
176 state.reset(new KernelClass<BooleanType>(ty, options));
177 return Status::OK();
178 }
179
180 template <typename Type>
181 enable_if_number<Type, Status> Visit(const Type&) {
182 auto ty = TypeTraits<typename KernelClass<Type>::SumType>::type_singleton();
183 state.reset(new KernelClass<Type>(ty, options));
184 return Status::OK();
185 }
186
187 template <typename Type>
188 enable_if_decimal<Type, Status> Visit(const Type&) {
189 state.reset(new KernelClass<Type>(type, options));
190 return Status::OK();
191 }
192
193 Result<std::unique_ptr<KernelState>> Create() {
194 RETURN_NOT_OK(VisitTypeInline(*type, this));
195 return std::move(state);
196 }
197 };
198
199 // ----------------------------------------------------------------------
200 // MinMax implementation
201
202 template <typename ArrowType, SimdLevel::type SimdLevel, typename Enable = void>
203 struct MinMaxState {};
204
205 template <typename ArrowType, SimdLevel::type SimdLevel>
206 struct MinMaxState<ArrowType, SimdLevel, enable_if_boolean<ArrowType>> {
207 using ThisType = MinMaxState<ArrowType, SimdLevel>;
208 using T = typename ArrowType::c_type;
209
210 ThisType& operator+=(const ThisType& rhs) {
211 this->has_nulls |= rhs.has_nulls;
212 this->min = this->min && rhs.min;
213 this->max = this->max || rhs.max;
214 return *this;
215 }
216
217 void MergeOne(T value) {
218 this->min = this->min && value;
219 this->max = this->max || value;
220 }
221
222 T min = true;
223 T max = false;
224 bool has_nulls = false;
225 };
226
227 template <typename ArrowType, SimdLevel::type SimdLevel>
228 struct MinMaxState<ArrowType, SimdLevel, enable_if_integer<ArrowType>> {
229 using ThisType = MinMaxState<ArrowType, SimdLevel>;
230 using T = typename ArrowType::c_type;
231 using ScalarType = typename TypeTraits<ArrowType>::ScalarType;
232
233 ThisType& operator+=(const ThisType& rhs) {
234 this->has_nulls |= rhs.has_nulls;
235 this->min = std::min(this->min, rhs.min);
236 this->max = std::max(this->max, rhs.max);
237 return *this;
238 }
239
240 void MergeOne(T value) {
241 this->min = std::min(this->min, value);
242 this->max = std::max(this->max, value);
243 }
244
245 T min = std::numeric_limits<T>::max();
246 T max = std::numeric_limits<T>::min();
247 bool has_nulls = false;
248 };
249
250 template <typename ArrowType, SimdLevel::type SimdLevel>
251 struct MinMaxState<ArrowType, SimdLevel, enable_if_floating_point<ArrowType>> {
252 using ThisType = MinMaxState<ArrowType, SimdLevel>;
253 using T = typename ArrowType::c_type;
254 using ScalarType = typename TypeTraits<ArrowType>::ScalarType;
255
256 ThisType& operator+=(const ThisType& rhs) {
257 this->has_nulls |= rhs.has_nulls;
258 this->min = std::fmin(this->min, rhs.min);
259 this->max = std::fmax(this->max, rhs.max);
260 return *this;
261 }
262
263 void MergeOne(T value) {
264 this->min = std::fmin(this->min, value);
265 this->max = std::fmax(this->max, value);
266 }
267
268 T min = std::numeric_limits<T>::infinity();
269 T max = -std::numeric_limits<T>::infinity();
270 bool has_nulls = false;
271 };
272
273 template <typename ArrowType, SimdLevel::type SimdLevel>
274 struct MinMaxState<ArrowType, SimdLevel, enable_if_decimal<ArrowType>> {
275 using ThisType = MinMaxState<ArrowType, SimdLevel>;
276 using T = typename TypeTraits<ArrowType>::CType;
277 using ScalarType = typename TypeTraits<ArrowType>::ScalarType;
278
279 MinMaxState() : min(T::GetMaxSentinel()), max(T::GetMinSentinel()) {}
280
281 ThisType& operator+=(const ThisType& rhs) {
282 this->has_nulls |= rhs.has_nulls;
283 this->min = std::min(this->min, rhs.min);
284 this->max = std::max(this->max, rhs.max);
285 return *this;
286 }
287
288 void MergeOne(util::string_view value) {
289 MergeOne(T(reinterpret_cast<const uint8_t*>(value.data())));
290 }
291
292 void MergeOne(const T value) {
293 this->min = std::min(this->min, value);
294 this->max = std::max(this->max, value);
295 }
296
297 T min;
298 T max;
299 bool has_nulls = false;
300 };
301
302 template <typename ArrowType, SimdLevel::type SimdLevel>
303 struct MinMaxState<ArrowType, SimdLevel,
304 enable_if_t<is_base_binary_type<ArrowType>::value ||
305 std::is_same<ArrowType, FixedSizeBinaryType>::value>> {
306 using ThisType = MinMaxState<ArrowType, SimdLevel>;
307 using ScalarType = typename TypeTraits<ArrowType>::ScalarType;
308
309 ThisType& operator+=(const ThisType& rhs) {
310 if (!this->seen && rhs.seen) {
311 this->min = rhs.min;
312 this->max = rhs.max;
313 } else if (this->seen && rhs.seen) {
314 if (this->min > rhs.min) {
315 this->min = rhs.min;
316 }
317 if (this->max < rhs.max) {
318 this->max = rhs.max;
319 }
320 }
321 this->has_nulls |= rhs.has_nulls;
322 this->seen |= rhs.seen;
323 return *this;
324 }
325
326 void MergeOne(util::string_view value) {
327 if (!seen) {
328 this->min = std::string(value);
329 this->max = std::string(value);
330 } else {
331 if (value < util::string_view(this->min)) {
332 this->min = std::string(value);
333 } else if (value > util::string_view(this->max)) {
334 this->max = std::string(value);
335 }
336 }
337 this->seen = true;
338 }
339
340 std::string min;
341 std::string max;
342 bool has_nulls = false;
343 bool seen = false;
344 };
345
346 template <typename ArrowType, SimdLevel::type SimdLevel>
347 struct MinMaxImpl : public ScalarAggregator {
348 using ArrayType = typename TypeTraits<ArrowType>::ArrayType;
349 using ThisType = MinMaxImpl<ArrowType, SimdLevel>;
350 using StateType = MinMaxState<ArrowType, SimdLevel>;
351
352 MinMaxImpl(std::shared_ptr<DataType> out_type, ScalarAggregateOptions options)
353 : out_type(std::move(out_type)), options(std::move(options)), count(0) {
354 this->options.min_count = std::max<uint32_t>(1, this->options.min_count);
355 }
356
357 Status Consume(KernelContext*, const ExecBatch& batch) override {
358 if (batch[0].is_array()) {
359 return ConsumeArray(ArrayType(batch[0].array()));
360 }
361 return ConsumeScalar(*batch[0].scalar());
362 }
363
364 Status ConsumeScalar(const Scalar& scalar) {
365 StateType local;
366 local.has_nulls = !scalar.is_valid;
367 this->count += scalar.is_valid;
368
369 if (local.has_nulls && !options.skip_nulls) {
370 this->state = local;
371 return Status::OK();
372 }
373
374 local.MergeOne(internal::UnboxScalar<ArrowType>::Unbox(scalar));
375 this->state = local;
376 return Status::OK();
377 }
378
379 Status ConsumeArray(const ArrayType& arr) {
380 StateType local;
381
382 const auto null_count = arr.null_count();
383 local.has_nulls = null_count > 0;
384 this->count += arr.length() - null_count;
385
386 if (local.has_nulls && !options.skip_nulls) {
387 this->state = local;
388 return Status::OK();
389 }
390
391 if (local.has_nulls) {
392 local += ConsumeWithNulls(arr);
393 } else { // All true values
394 for (int64_t i = 0; i < arr.length(); i++) {
395 local.MergeOne(arr.GetView(i));
396 }
397 }
398 this->state = local;
399 return Status::OK();
400 }
401
402 Status MergeFrom(KernelContext*, KernelState&& src) override {
403 const auto& other = checked_cast<const ThisType&>(src);
404 this->state += other.state;
405 this->count += other.count;
406 return Status::OK();
407 }
408
409 Status Finalize(KernelContext*, Datum* out) override {
410 const auto& struct_type = checked_cast<const StructType&>(*out_type);
411 const auto& child_type = struct_type.field(0)->type();
412
413 std::vector<std::shared_ptr<Scalar>> values;
414 // Physical type != result type
415 if ((state.has_nulls && !options.skip_nulls) || (this->count < options.min_count)) {
416 // (null, null)
417 auto null_scalar = MakeNullScalar(child_type);
418 values = {null_scalar, null_scalar};
419 } else {
420 ARROW_ASSIGN_OR_RAISE(auto min_scalar,
421 MakeScalar(child_type, std::move(state.min)));
422 ARROW_ASSIGN_OR_RAISE(auto max_scalar,
423 MakeScalar(child_type, std::move(state.max)));
424 values = {std::move(min_scalar), std::move(max_scalar)};
425 }
426 out->value = std::make_shared<StructScalar>(std::move(values), this->out_type);
427 return Status::OK();
428 }
429
430 std::shared_ptr<DataType> out_type;
431 ScalarAggregateOptions options;
432 int64_t count;
433 MinMaxState<ArrowType, SimdLevel> state;
434
435 private:
436 StateType ConsumeWithNulls(const ArrayType& arr) const {
437 StateType local;
438 const int64_t length = arr.length();
439 int64_t offset = arr.offset();
440 const uint8_t* bitmap = arr.null_bitmap_data();
441 int64_t idx = 0;
442
443 const auto p = arrow::internal::BitmapWordAlign<1>(bitmap, offset, length);
444 // First handle the leading bits
445 const int64_t leading_bits = p.leading_bits;
446 while (idx < leading_bits) {
447 if (BitUtil::GetBit(bitmap, offset)) {
448 local.MergeOne(arr.GetView(idx));
449 }
450 idx++;
451 offset++;
452 }
453
454 // The aligned parts scanned with BitBlockCounter
455 arrow::internal::BitBlockCounter data_counter(bitmap, offset, length - leading_bits);
456 auto current_block = data_counter.NextWord();
457 while (idx < length) {
458 if (current_block.AllSet()) { // All true values
459 int run_length = 0;
460 // Scan forward until a block that has some false values (or the end)
461 while (current_block.length > 0 && current_block.AllSet()) {
462 run_length += current_block.length;
463 current_block = data_counter.NextWord();
464 }
465 for (int64_t i = 0; i < run_length; i++) {
466 local.MergeOne(arr.GetView(idx + i));
467 }
468 idx += run_length;
469 offset += run_length;
470 // The current_block already computed, advance to next loop
471 continue;
472 } else if (!current_block.NoneSet()) { // Some values are null
473 BitmapReader reader(arr.null_bitmap_data(), offset, current_block.length);
474 for (int64_t i = 0; i < current_block.length; i++) {
475 if (reader.IsSet()) {
476 local.MergeOne(arr.GetView(idx + i));
477 }
478 reader.Next();
479 }
480
481 idx += current_block.length;
482 offset += current_block.length;
483 } else { // All null values
484 idx += current_block.length;
485 offset += current_block.length;
486 }
487 current_block = data_counter.NextWord();
488 }
489
490 return local;
491 }
492 };
493
494 template <SimdLevel::type SimdLevel>
495 struct BooleanMinMaxImpl : public MinMaxImpl<BooleanType, SimdLevel> {
496 using StateType = MinMaxState<BooleanType, SimdLevel>;
497 using ArrayType = typename TypeTraits<BooleanType>::ArrayType;
498 using MinMaxImpl<BooleanType, SimdLevel>::MinMaxImpl;
499 using MinMaxImpl<BooleanType, SimdLevel>::options;
500
501 Status Consume(KernelContext*, const ExecBatch& batch) override {
502 if (ARROW_PREDICT_FALSE(batch[0].is_scalar())) {
503 return ConsumeScalar(checked_cast<const BooleanScalar&>(*batch[0].scalar()));
504 }
505 StateType local;
506 ArrayType arr(batch[0].array());
507
508 const auto arr_length = arr.length();
509 const auto null_count = arr.null_count();
510 const auto valid_count = arr_length - null_count;
511
512 local.has_nulls = null_count > 0;
513 this->count += valid_count;
514 if (local.has_nulls && !options.skip_nulls) {
515 this->state = local;
516 return Status::OK();
517 }
518
519 const auto true_count = arr.true_count();
520 const auto false_count = valid_count - true_count;
521 local.max = true_count > 0;
522 local.min = false_count == 0;
523
524 this->state = local;
525 return Status::OK();
526 }
527
528 Status ConsumeScalar(const BooleanScalar& scalar) {
529 StateType local;
530
531 local.has_nulls = !scalar.is_valid;
532 this->count += scalar.is_valid;
533 if (local.has_nulls && !options.skip_nulls) {
534 this->state = local;
535 return Status::OK();
536 }
537
538 const int true_count = scalar.is_valid && scalar.value;
539 const int false_count = scalar.is_valid && !scalar.value;
540 local.max = true_count > 0;
541 local.min = false_count == 0;
542
543 this->state = local;
544 return Status::OK();
545 }
546 };
547
548 struct NullMinMaxImpl : public ScalarAggregator {
549 Status Consume(KernelContext*, const ExecBatch& batch) override { return Status::OK(); }
550
551 Status MergeFrom(KernelContext*, KernelState&& src) override { return Status::OK(); }
552
553 Status Finalize(KernelContext*, Datum* out) override {
554 std::vector<std::shared_ptr<Scalar>> values{std::make_shared<NullScalar>(),
555 std::make_shared<NullScalar>()};
556 out->value = std::make_shared<StructScalar>(
557 std::move(values), struct_({field("min", null()), field("max", null())}));
558 return Status::OK();
559 }
560 };
561
562 template <SimdLevel::type SimdLevel>
563 struct MinMaxInitState {
564 std::unique_ptr<KernelState> state;
565 KernelContext* ctx;
566 const DataType& in_type;
567 const std::shared_ptr<DataType>& out_type;
568 const ScalarAggregateOptions& options;
569
570 MinMaxInitState(KernelContext* ctx, const DataType& in_type,
571 const std::shared_ptr<DataType>& out_type,
572 const ScalarAggregateOptions& options)
573 : ctx(ctx), in_type(in_type), out_type(out_type), options(options) {}
574
575 Status Visit(const DataType& ty) {
576 return Status::NotImplemented("No min/max implemented for ", ty);
577 }
578
579 Status Visit(const HalfFloatType& ty) {
580 return Status::NotImplemented("No min/max implemented for ", ty);
581 }
582
583 Status Visit(const NullType&) {
584 state.reset(new NullMinMaxImpl());
585 return Status::OK();
586 }
587
588 Status Visit(const BooleanType&) {
589 state.reset(new BooleanMinMaxImpl<SimdLevel>(out_type, options));
590 return Status::OK();
591 }
592
593 template <typename Type>
594 enable_if_physical_integer<Type, Status> Visit(const Type&) {
595 using PhysicalType = typename Type::PhysicalType;
596 state.reset(new MinMaxImpl<PhysicalType, SimdLevel>(out_type, options));
597 return Status::OK();
598 }
599
600 template <typename Type>
601 enable_if_floating_point<Type, Status> Visit(const Type&) {
602 state.reset(new MinMaxImpl<Type, SimdLevel>(out_type, options));
603 return Status::OK();
604 }
605
606 template <typename Type>
607 enable_if_base_binary<Type, Status> Visit(const Type&) {
608 state.reset(new MinMaxImpl<Type, SimdLevel>(out_type, options));
609 return Status::OK();
610 }
611
612 template <typename Type>
613 enable_if_fixed_size_binary<Type, Status> Visit(const Type&) {
614 state.reset(new MinMaxImpl<Type, SimdLevel>(out_type, options));
615 return Status::OK();
616 }
617
618 Result<std::unique_ptr<KernelState>> Create() {
619 RETURN_NOT_OK(VisitTypeInline(in_type, this));
620 return std::move(state);
621 }
622 };
623
624 } // namespace internal
625 } // namespace compute
626 } // namespace arrow