// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <limits>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "arrow/compute/api_aggregate.h"
#include "arrow/compute/kernels/aggregate_internal.h"
#include "arrow/compute/kernels/codegen_internal.h"
#include "arrow/compute/kernels/common.h"
#include "arrow/util/align_util.h"
#include "arrow/util/bit_block_counter.h"
#include "arrow/util/decimal.h"
35 void AddBasicAggKernels(KernelInit init
,
36 const std::vector
<std::shared_ptr
<DataType
>>& types
,
37 std::shared_ptr
<DataType
> out_ty
, ScalarAggregateFunction
* func
,
38 SimdLevel::type simd_level
= SimdLevel::NONE
);
40 void AddMinMaxKernels(KernelInit init
,
41 const std::vector
<std::shared_ptr
<DataType
>>& types
,
42 ScalarAggregateFunction
* func
,
43 SimdLevel::type simd_level
= SimdLevel::NONE
);
44 void AddMinMaxKernel(KernelInit init
, internal::detail::GetTypeId get_id
,
45 ScalarAggregateFunction
* func
,
46 SimdLevel::type simd_level
= SimdLevel::NONE
);
48 // SIMD variants for kernels
49 void AddSumAvx2AggKernels(ScalarAggregateFunction
* func
);
50 void AddMeanAvx2AggKernels(ScalarAggregateFunction
* func
);
51 void AddMinMaxAvx2AggKernels(ScalarAggregateFunction
* func
);
53 void AddSumAvx512AggKernels(ScalarAggregateFunction
* func
);
54 void AddMeanAvx512AggKernels(ScalarAggregateFunction
* func
);
55 void AddMinMaxAvx512AggKernels(ScalarAggregateFunction
* func
);
57 // ----------------------------------------------------------------------
60 template <typename ArrowType
, SimdLevel::type SimdLevel
>
61 struct SumImpl
: public ScalarAggregator
{
62 using ThisType
= SumImpl
<ArrowType
, SimdLevel
>;
63 using CType
= typename TypeTraits
<ArrowType
>::CType
;
64 using SumType
= typename FindAccumulatorType
<ArrowType
>::Type
;
65 using SumCType
= typename TypeTraits
<SumType
>::CType
;
66 using OutputType
= typename TypeTraits
<SumType
>::ScalarType
;
68 SumImpl(const std::shared_ptr
<DataType
>& out_type
,
69 const ScalarAggregateOptions
& options_
)
70 : out_type(out_type
), options(options_
) {}
72 Status
Consume(KernelContext
*, const ExecBatch
& batch
) override
{
73 if (batch
[0].is_array()) {
74 const auto& data
= batch
[0].array();
75 this->count
+= data
->length
- data
->GetNullCount();
76 this->nulls_observed
= this->nulls_observed
|| data
->GetNullCount();
78 if (!options
.skip_nulls
&& this->nulls_observed
) {
83 if (is_boolean_type
<ArrowType
>::value
) {
84 this->sum
+= static_cast<SumCType
>(BooleanArray(data
).true_count());
86 this->sum
+= SumArray
<CType
, SumCType
, SimdLevel
>(*data
);
89 const auto& data
= *batch
[0].scalar();
90 this->count
+= data
.is_valid
* batch
.length
;
91 this->nulls_observed
= this->nulls_observed
|| !data
.is_valid
;
93 this->sum
+= internal::UnboxScalar
<ArrowType
>::Unbox(data
) * batch
.length
;
99 Status
MergeFrom(KernelContext
*, KernelState
&& src
) override
{
100 const auto& other
= checked_cast
<const ThisType
&>(src
);
101 this->count
+= other
.count
;
102 this->sum
+= other
.sum
;
103 this->nulls_observed
= this->nulls_observed
|| other
.nulls_observed
;
107 Status
Finalize(KernelContext
*, Datum
* out
) override
{
108 if ((!options
.skip_nulls
&& this->nulls_observed
) ||
109 (this->count
< options
.min_count
)) {
110 out
->value
= std::make_shared
<OutputType
>(out_type
);
112 out
->value
= std::make_shared
<OutputType
>(this->sum
, out_type
);
118 bool nulls_observed
= false;
120 std::shared_ptr
<DataType
> out_type
;
121 ScalarAggregateOptions options
;
124 template <typename ArrowType
, SimdLevel::type SimdLevel
>
125 struct MeanImpl
: public SumImpl
<ArrowType
, SimdLevel
> {
126 using SumImpl
<ArrowType
, SimdLevel
>::SumImpl
;
128 template <typename T
= ArrowType
>
129 enable_if_decimal
<T
, Status
> FinalizeImpl(Datum
* out
) {
130 using SumCType
= typename SumImpl
<ArrowType
, SimdLevel
>::SumCType
;
131 using OutputType
= typename SumImpl
<ArrowType
, SimdLevel
>::OutputType
;
132 if ((!options
.skip_nulls
&& this->nulls_observed
) ||
133 (this->count
< options
.min_count
) || (this->count
== 0)) {
134 out
->value
= std::make_shared
<OutputType
>(this->out_type
);
136 const SumCType mean
= this->sum
/ this->count
;
137 out
->value
= std::make_shared
<OutputType
>(mean
, this->out_type
);
141 template <typename T
= ArrowType
>
142 enable_if_t
<!is_decimal_type
<T
>::value
, Status
> FinalizeImpl(Datum
* out
) {
143 if ((!options
.skip_nulls
&& this->nulls_observed
) ||
144 (this->count
< options
.min_count
)) {
145 out
->value
= std::make_shared
<DoubleScalar
>();
147 const double mean
= static_cast<double>(this->sum
) / this->count
;
148 out
->value
= std::make_shared
<DoubleScalar
>(mean
);
152 Status
Finalize(KernelContext
*, Datum
* out
) override
{ return FinalizeImpl(out
); }
154 using SumImpl
<ArrowType
, SimdLevel
>::options
;
157 template <template <typename
> class KernelClass
>
159 std::unique_ptr
<KernelState
> state
;
161 const std::shared_ptr
<DataType
> type
;
162 const ScalarAggregateOptions
& options
;
164 SumLikeInit(KernelContext
* ctx
, const std::shared_ptr
<DataType
>& type
,
165 const ScalarAggregateOptions
& options
)
166 : ctx(ctx
), type(type
), options(options
) {}
168 Status
Visit(const DataType
&) { return Status::NotImplemented("No sum implemented"); }
170 Status
Visit(const HalfFloatType
&) {
171 return Status::NotImplemented("No sum implemented");
174 Status
Visit(const BooleanType
&) {
175 auto ty
= TypeTraits
<typename KernelClass
<BooleanType
>::SumType
>::type_singleton();
176 state
.reset(new KernelClass
<BooleanType
>(ty
, options
));
180 template <typename Type
>
181 enable_if_number
<Type
, Status
> Visit(const Type
&) {
182 auto ty
= TypeTraits
<typename KernelClass
<Type
>::SumType
>::type_singleton();
183 state
.reset(new KernelClass
<Type
>(ty
, options
));
187 template <typename Type
>
188 enable_if_decimal
<Type
, Status
> Visit(const Type
&) {
189 state
.reset(new KernelClass
<Type
>(type
, options
));
193 Result
<std::unique_ptr
<KernelState
>> Create() {
194 RETURN_NOT_OK(VisitTypeInline(*type
, this));
195 return std::move(state
);
199 // ----------------------------------------------------------------------
200 // MinMax implementation
202 template <typename ArrowType
, SimdLevel::type SimdLevel
, typename Enable
= void>
203 struct MinMaxState
{};
205 template <typename ArrowType
, SimdLevel::type SimdLevel
>
206 struct MinMaxState
<ArrowType
, SimdLevel
, enable_if_boolean
<ArrowType
>> {
207 using ThisType
= MinMaxState
<ArrowType
, SimdLevel
>;
208 using T
= typename
ArrowType::c_type
;
210 ThisType
& operator+=(const ThisType
& rhs
) {
211 this->has_nulls
|= rhs
.has_nulls
;
212 this->min
= this->min
&& rhs
.min
;
213 this->max
= this->max
|| rhs
.max
;
217 void MergeOne(T value
) {
218 this->min
= this->min
&& value
;
219 this->max
= this->max
|| value
;
224 bool has_nulls
= false;
227 template <typename ArrowType
, SimdLevel::type SimdLevel
>
228 struct MinMaxState
<ArrowType
, SimdLevel
, enable_if_integer
<ArrowType
>> {
229 using ThisType
= MinMaxState
<ArrowType
, SimdLevel
>;
230 using T
= typename
ArrowType::c_type
;
231 using ScalarType
= typename TypeTraits
<ArrowType
>::ScalarType
;
233 ThisType
& operator+=(const ThisType
& rhs
) {
234 this->has_nulls
|= rhs
.has_nulls
;
235 this->min
= std::min(this->min
, rhs
.min
);
236 this->max
= std::max(this->max
, rhs
.max
);
240 void MergeOne(T value
) {
241 this->min
= std::min(this->min
, value
);
242 this->max
= std::max(this->max
, value
);
245 T min
= std::numeric_limits
<T
>::max();
246 T max
= std::numeric_limits
<T
>::min();
247 bool has_nulls
= false;
250 template <typename ArrowType
, SimdLevel::type SimdLevel
>
251 struct MinMaxState
<ArrowType
, SimdLevel
, enable_if_floating_point
<ArrowType
>> {
252 using ThisType
= MinMaxState
<ArrowType
, SimdLevel
>;
253 using T
= typename
ArrowType::c_type
;
254 using ScalarType
= typename TypeTraits
<ArrowType
>::ScalarType
;
256 ThisType
& operator+=(const ThisType
& rhs
) {
257 this->has_nulls
|= rhs
.has_nulls
;
258 this->min
= std::fmin(this->min
, rhs
.min
);
259 this->max
= std::fmax(this->max
, rhs
.max
);
263 void MergeOne(T value
) {
264 this->min
= std::fmin(this->min
, value
);
265 this->max
= std::fmax(this->max
, value
);
268 T min
= std::numeric_limits
<T
>::infinity();
269 T max
= -std::numeric_limits
<T
>::infinity();
270 bool has_nulls
= false;
273 template <typename ArrowType
, SimdLevel::type SimdLevel
>
274 struct MinMaxState
<ArrowType
, SimdLevel
, enable_if_decimal
<ArrowType
>> {
275 using ThisType
= MinMaxState
<ArrowType
, SimdLevel
>;
276 using T
= typename TypeTraits
<ArrowType
>::CType
;
277 using ScalarType
= typename TypeTraits
<ArrowType
>::ScalarType
;
279 MinMaxState() : min(T::GetMaxSentinel()), max(T::GetMinSentinel()) {}
281 ThisType
& operator+=(const ThisType
& rhs
) {
282 this->has_nulls
|= rhs
.has_nulls
;
283 this->min
= std::min(this->min
, rhs
.min
);
284 this->max
= std::max(this->max
, rhs
.max
);
288 void MergeOne(util::string_view value
) {
289 MergeOne(T(reinterpret_cast<const uint8_t*>(value
.data())));
292 void MergeOne(const T value
) {
293 this->min
= std::min(this->min
, value
);
294 this->max
= std::max(this->max
, value
);
299 bool has_nulls
= false;
302 template <typename ArrowType
, SimdLevel::type SimdLevel
>
303 struct MinMaxState
<ArrowType
, SimdLevel
,
304 enable_if_t
<is_base_binary_type
<ArrowType
>::value
||
305 std::is_same
<ArrowType
, FixedSizeBinaryType
>::value
>> {
306 using ThisType
= MinMaxState
<ArrowType
, SimdLevel
>;
307 using ScalarType
= typename TypeTraits
<ArrowType
>::ScalarType
;
309 ThisType
& operator+=(const ThisType
& rhs
) {
310 if (!this->seen
&& rhs
.seen
) {
313 } else if (this->seen
&& rhs
.seen
) {
314 if (this->min
> rhs
.min
) {
317 if (this->max
< rhs
.max
) {
321 this->has_nulls
|= rhs
.has_nulls
;
322 this->seen
|= rhs
.seen
;
326 void MergeOne(util::string_view value
) {
328 this->min
= std::string(value
);
329 this->max
= std::string(value
);
331 if (value
< util::string_view(this->min
)) {
332 this->min
= std::string(value
);
333 } else if (value
> util::string_view(this->max
)) {
334 this->max
= std::string(value
);
342 bool has_nulls
= false;
346 template <typename ArrowType
, SimdLevel::type SimdLevel
>
347 struct MinMaxImpl
: public ScalarAggregator
{
348 using ArrayType
= typename TypeTraits
<ArrowType
>::ArrayType
;
349 using ThisType
= MinMaxImpl
<ArrowType
, SimdLevel
>;
350 using StateType
= MinMaxState
<ArrowType
, SimdLevel
>;
352 MinMaxImpl(std::shared_ptr
<DataType
> out_type
, ScalarAggregateOptions options
)
353 : out_type(std::move(out_type
)), options(std::move(options
)), count(0) {
354 this->options
.min_count
= std::max
<uint32_t>(1, this->options
.min_count
);
357 Status
Consume(KernelContext
*, const ExecBatch
& batch
) override
{
358 if (batch
[0].is_array()) {
359 return ConsumeArray(ArrayType(batch
[0].array()));
361 return ConsumeScalar(*batch
[0].scalar());
364 Status
ConsumeScalar(const Scalar
& scalar
) {
366 local
.has_nulls
= !scalar
.is_valid
;
367 this->count
+= scalar
.is_valid
;
369 if (local
.has_nulls
&& !options
.skip_nulls
) {
374 local
.MergeOne(internal::UnboxScalar
<ArrowType
>::Unbox(scalar
));
379 Status
ConsumeArray(const ArrayType
& arr
) {
382 const auto null_count
= arr
.null_count();
383 local
.has_nulls
= null_count
> 0;
384 this->count
+= arr
.length() - null_count
;
386 if (local
.has_nulls
&& !options
.skip_nulls
) {
391 if (local
.has_nulls
) {
392 local
+= ConsumeWithNulls(arr
);
393 } else { // All true values
394 for (int64_t i
= 0; i
< arr
.length(); i
++) {
395 local
.MergeOne(arr
.GetView(i
));
402 Status
MergeFrom(KernelContext
*, KernelState
&& src
) override
{
403 const auto& other
= checked_cast
<const ThisType
&>(src
);
404 this->state
+= other
.state
;
405 this->count
+= other
.count
;
409 Status
Finalize(KernelContext
*, Datum
* out
) override
{
410 const auto& struct_type
= checked_cast
<const StructType
&>(*out_type
);
411 const auto& child_type
= struct_type
.field(0)->type();
413 std::vector
<std::shared_ptr
<Scalar
>> values
;
414 // Physical type != result type
415 if ((state
.has_nulls
&& !options
.skip_nulls
) || (this->count
< options
.min_count
)) {
417 auto null_scalar
= MakeNullScalar(child_type
);
418 values
= {null_scalar
, null_scalar
};
420 ARROW_ASSIGN_OR_RAISE(auto min_scalar
,
421 MakeScalar(child_type
, std::move(state
.min
)));
422 ARROW_ASSIGN_OR_RAISE(auto max_scalar
,
423 MakeScalar(child_type
, std::move(state
.max
)));
424 values
= {std::move(min_scalar
), std::move(max_scalar
)};
426 out
->value
= std::make_shared
<StructScalar
>(std::move(values
), this->out_type
);
430 std::shared_ptr
<DataType
> out_type
;
431 ScalarAggregateOptions options
;
433 MinMaxState
<ArrowType
, SimdLevel
> state
;
436 StateType
ConsumeWithNulls(const ArrayType
& arr
) const {
438 const int64_t length
= arr
.length();
439 int64_t offset
= arr
.offset();
440 const uint8_t* bitmap
= arr
.null_bitmap_data();
443 const auto p
= arrow::internal::BitmapWordAlign
<1>(bitmap
, offset
, length
);
444 // First handle the leading bits
445 const int64_t leading_bits
= p
.leading_bits
;
446 while (idx
< leading_bits
) {
447 if (BitUtil::GetBit(bitmap
, offset
)) {
448 local
.MergeOne(arr
.GetView(idx
));
454 // The aligned parts scanned with BitBlockCounter
455 arrow::internal::BitBlockCounter
data_counter(bitmap
, offset
, length
- leading_bits
);
456 auto current_block
= data_counter
.NextWord();
457 while (idx
< length
) {
458 if (current_block
.AllSet()) { // All true values
460 // Scan forward until a block that has some false values (or the end)
461 while (current_block
.length
> 0 && current_block
.AllSet()) {
462 run_length
+= current_block
.length
;
463 current_block
= data_counter
.NextWord();
465 for (int64_t i
= 0; i
< run_length
; i
++) {
466 local
.MergeOne(arr
.GetView(idx
+ i
));
469 offset
+= run_length
;
470 // The current_block already computed, advance to next loop
472 } else if (!current_block
.NoneSet()) { // Some values are null
473 BitmapReader
reader(arr
.null_bitmap_data(), offset
, current_block
.length
);
474 for (int64_t i
= 0; i
< current_block
.length
; i
++) {
475 if (reader
.IsSet()) {
476 local
.MergeOne(arr
.GetView(idx
+ i
));
481 idx
+= current_block
.length
;
482 offset
+= current_block
.length
;
483 } else { // All null values
484 idx
+= current_block
.length
;
485 offset
+= current_block
.length
;
487 current_block
= data_counter
.NextWord();
494 template <SimdLevel::type SimdLevel
>
495 struct BooleanMinMaxImpl
: public MinMaxImpl
<BooleanType
, SimdLevel
> {
496 using StateType
= MinMaxState
<BooleanType
, SimdLevel
>;
497 using ArrayType
= typename TypeTraits
<BooleanType
>::ArrayType
;
498 using MinMaxImpl
<BooleanType
, SimdLevel
>::MinMaxImpl
;
499 using MinMaxImpl
<BooleanType
, SimdLevel
>::options
;
501 Status
Consume(KernelContext
*, const ExecBatch
& batch
) override
{
502 if (ARROW_PREDICT_FALSE(batch
[0].is_scalar())) {
503 return ConsumeScalar(checked_cast
<const BooleanScalar
&>(*batch
[0].scalar()));
506 ArrayType
arr(batch
[0].array());
508 const auto arr_length
= arr
.length();
509 const auto null_count
= arr
.null_count();
510 const auto valid_count
= arr_length
- null_count
;
512 local
.has_nulls
= null_count
> 0;
513 this->count
+= valid_count
;
514 if (local
.has_nulls
&& !options
.skip_nulls
) {
519 const auto true_count
= arr
.true_count();
520 const auto false_count
= valid_count
- true_count
;
521 local
.max
= true_count
> 0;
522 local
.min
= false_count
== 0;
528 Status
ConsumeScalar(const BooleanScalar
& scalar
) {
531 local
.has_nulls
= !scalar
.is_valid
;
532 this->count
+= scalar
.is_valid
;
533 if (local
.has_nulls
&& !options
.skip_nulls
) {
538 const int true_count
= scalar
.is_valid
&& scalar
.value
;
539 const int false_count
= scalar
.is_valid
&& !scalar
.value
;
540 local
.max
= true_count
> 0;
541 local
.min
= false_count
== 0;
548 struct NullMinMaxImpl
: public ScalarAggregator
{
549 Status
Consume(KernelContext
*, const ExecBatch
& batch
) override
{ return Status::OK(); }
551 Status
MergeFrom(KernelContext
*, KernelState
&& src
) override
{ return Status::OK(); }
553 Status
Finalize(KernelContext
*, Datum
* out
) override
{
554 std::vector
<std::shared_ptr
<Scalar
>> values
{std::make_shared
<NullScalar
>(),
555 std::make_shared
<NullScalar
>()};
556 out
->value
= std::make_shared
<StructScalar
>(
557 std::move(values
), struct_({field("min", null()), field("max", null())}));
562 template <SimdLevel::type SimdLevel
>
563 struct MinMaxInitState
{
564 std::unique_ptr
<KernelState
> state
;
566 const DataType
& in_type
;
567 const std::shared_ptr
<DataType
>& out_type
;
568 const ScalarAggregateOptions
& options
;
570 MinMaxInitState(KernelContext
* ctx
, const DataType
& in_type
,
571 const std::shared_ptr
<DataType
>& out_type
,
572 const ScalarAggregateOptions
& options
)
573 : ctx(ctx
), in_type(in_type
), out_type(out_type
), options(options
) {}
575 Status
Visit(const DataType
& ty
) {
576 return Status::NotImplemented("No min/max implemented for ", ty
);
579 Status
Visit(const HalfFloatType
& ty
) {
580 return Status::NotImplemented("No min/max implemented for ", ty
);
583 Status
Visit(const NullType
&) {
584 state
.reset(new NullMinMaxImpl());
588 Status
Visit(const BooleanType
&) {
589 state
.reset(new BooleanMinMaxImpl
<SimdLevel
>(out_type
, options
));
593 template <typename Type
>
594 enable_if_physical_integer
<Type
, Status
> Visit(const Type
&) {
595 using PhysicalType
= typename
Type::PhysicalType
;
596 state
.reset(new MinMaxImpl
<PhysicalType
, SimdLevel
>(out_type
, options
));
600 template <typename Type
>
601 enable_if_floating_point
<Type
, Status
> Visit(const Type
&) {
602 state
.reset(new MinMaxImpl
<Type
, SimdLevel
>(out_type
, options
));
606 template <typename Type
>
607 enable_if_base_binary
<Type
, Status
> Visit(const Type
&) {
608 state
.reset(new MinMaxImpl
<Type
, SimdLevel
>(out_type
, options
));
612 template <typename Type
>
613 enable_if_fixed_size_binary
<Type
, Status
> Visit(const Type
&) {
614 state
.reset(new MinMaxImpl
<Type
, SimdLevel
>(out_type
, options
));
618 Result
<std::unique_ptr
<KernelState
>> Create() {
619 RETURN_NOT_OK(VisitTypeInline(in_type
, this));
620 return std::move(state
);
624 } // namespace internal
625 } // namespace compute