1 // Licensed to the Apache Software Foundation (ASF) under one
2 // or more contributor license agreements. See the NOTICE file
3 // distributed with this work for additional information
4 // regarding copyright ownership. The ASF licenses this file
5 // to you under the Apache License, Version 2.0 (the
6 // "License"); you may not use this file except in compliance
7 // with the License. You may obtain a copy of the License at
9 // http://www.apache.org/licenses/LICENSE-2.0
11 // Unless required by applicable law or agreed to in writing,
12 // software distributed under the License is distributed on an
13 // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14 // KIND, either express or implied. See the License for the
15 // specific language governing permissions and limitations
18 #include "arrow/compute/api_aggregate.h"
19 #include "arrow/compute/kernels/aggregate_basic_internal.h"
20 #include "arrow/compute/kernels/aggregate_internal.h"
21 #include "arrow/compute/kernels/common.h"
22 #include "arrow/compute/kernels/util_internal.h"
23 #include "arrow/util/cpu_info.h"
24 #include "arrow/util/hashing.h"
25 #include "arrow/util/make_unique.h"
33 Status
AggregateConsume(KernelContext
* ctx
, const ExecBatch
& batch
) {
34 return checked_cast
<ScalarAggregator
*>(ctx
->state())->Consume(ctx
, batch
);
37 Status
AggregateMerge(KernelContext
* ctx
, KernelState
&& src
, KernelState
* dst
) {
38 return checked_cast
<ScalarAggregator
*>(dst
)->MergeFrom(ctx
, std::move(src
));
41 Status
AggregateFinalize(KernelContext
* ctx
, Datum
* out
) {
42 return checked_cast
<ScalarAggregator
*>(ctx
->state())->Finalize(ctx
, out
);
47 void AddAggKernel(std::shared_ptr
<KernelSignature
> sig
, KernelInit init
,
48 ScalarAggregateFunction
* func
, SimdLevel::type simd_level
) {
49 ScalarAggregateKernel
kernel(std::move(sig
), std::move(init
), AggregateConsume
,
50 AggregateMerge
, AggregateFinalize
);
52 kernel
.simd_level
= simd_level
;
53 DCHECK_OK(func
->AddKernel(std::move(kernel
)));
56 void AddAggKernel(std::shared_ptr
<KernelSignature
> sig
, KernelInit init
,
57 ScalarAggregateFinalize finalize
, ScalarAggregateFunction
* func
,
58 SimdLevel::type simd_level
) {
59 ScalarAggregateKernel
kernel(std::move(sig
), std::move(init
), AggregateConsume
,
60 AggregateMerge
, std::move(finalize
));
62 kernel
.simd_level
= simd_level
;
63 DCHECK_OK(func
->AddKernel(std::move(kernel
)));
68 // ----------------------------------------------------------------------
69 // Count implementation
71 struct CountImpl
: public ScalarAggregator
{
72 explicit CountImpl(CountOptions options
) : options(std::move(options
)) {}
74 Status
Consume(KernelContext
*, const ExecBatch
& batch
) override
{
75 if (options
.mode
== CountOptions::ALL
) {
76 this->non_nulls
+= batch
.length
;
77 } else if (batch
[0].is_array()) {
78 const ArrayData
& input
= *batch
[0].array();
79 const int64_t nulls
= input
.GetNullCount();
81 this->non_nulls
+= input
.length
- nulls
;
83 const Scalar
& input
= *batch
[0].scalar();
84 this->nulls
+= !input
.is_valid
* batch
.length
;
85 this->non_nulls
+= input
.is_valid
* batch
.length
;
90 Status
MergeFrom(KernelContext
*, KernelState
&& src
) override
{
91 const auto& other_state
= checked_cast
<const CountImpl
&>(src
);
92 this->non_nulls
+= other_state
.non_nulls
;
93 this->nulls
+= other_state
.nulls
;
97 Status
Finalize(KernelContext
* ctx
, Datum
* out
) override
{
98 const auto& state
= checked_cast
<const CountImpl
&>(*ctx
->state());
99 switch (state
.options
.mode
) {
100 case CountOptions::ONLY_VALID
:
101 case CountOptions::ALL
:
102 // ALL is equivalent since we don't count the null/non-null
103 // separately to avoid potentially computing null count
104 *out
= Datum(state
.non_nulls
);
106 case CountOptions::ONLY_NULL
:
107 *out
= Datum(state
.nulls
);
110 DCHECK(false) << "unreachable";
115 CountOptions options
;
116 int64_t non_nulls
= 0;
120 Result
<std::unique_ptr
<KernelState
>> CountInit(KernelContext
*,
121 const KernelInitArgs
& args
) {
122 return ::arrow::internal::make_unique
<CountImpl
>(
123 static_cast<const CountOptions
&>(*args
.options
));
126 // ----------------------------------------------------------------------
127 // Distinct Count implementation
129 template <typename Type
, typename VisitorArgType
>
130 struct CountDistinctImpl
: public ScalarAggregator
{
131 using MemoTable
= typename
arrow::internal::HashTraits
<Type
>::MemoTableType
;
133 explicit CountDistinctImpl(MemoryPool
* memory_pool
, CountOptions options
)
134 : options(std::move(options
)), memo_table_(new MemoTable(memory_pool
, 0)) {}
136 Status
Consume(KernelContext
*, const ExecBatch
& batch
) override
{
137 if (batch
[0].is_array()) {
138 const ArrayData
& arr
= *batch
[0].array();
139 auto visit_null
= []() { return Status::OK(); };
140 auto visit_value
= [&](VisitorArgType arg
) {
142 return memo_table_
->GetOrInsert(arg
, &y
);
144 RETURN_NOT_OK(VisitArrayDataInline
<Type
>(arr
, visit_value
, visit_null
));
145 this->non_nulls
+= memo_table_
->size();
146 this->has_nulls
= arr
.GetNullCount() > 0;
148 const Scalar
& input
= *batch
[0].scalar();
149 this->has_nulls
= !input
.is_valid
;
150 if (input
.is_valid
) {
151 this->non_nulls
+= batch
.length
;
157 Status
MergeFrom(KernelContext
*, KernelState
&& src
) override
{
158 const auto& other_state
= checked_cast
<const CountDistinctImpl
&>(src
);
159 this->non_nulls
+= other_state
.non_nulls
;
160 this->has_nulls
= this->has_nulls
|| other_state
.has_nulls
;
164 Status
Finalize(KernelContext
* ctx
, Datum
* out
) override
{
165 const auto& state
= checked_cast
<const CountDistinctImpl
&>(*ctx
->state());
166 const int64_t nulls
= state
.has_nulls
? 1 : 0;
167 switch (state
.options
.mode
) {
168 case CountOptions::ONLY_VALID
:
169 *out
= Datum(state
.non_nulls
);
171 case CountOptions::ALL
:
172 *out
= Datum(state
.non_nulls
+ nulls
);
174 case CountOptions::ONLY_NULL
:
178 DCHECK(false) << "unreachable";
183 const CountOptions options
;
184 int64_t non_nulls
= 0;
185 bool has_nulls
= false;
186 std::unique_ptr
<MemoTable
> memo_table_
;
189 template <typename Type
, typename VisitorArgType
>
190 Result
<std::unique_ptr
<KernelState
>> CountDistinctInit(KernelContext
* ctx
,
191 const KernelInitArgs
& args
) {
192 return ::arrow::internal::make_unique
<CountDistinctImpl
<Type
, VisitorArgType
>>(
193 ctx
->memory_pool(), static_cast<const CountOptions
&>(*args
.options
));
196 template <typename Type
, typename VisitorArgType
= typename
Type::c_type
>
197 void AddCountDistinctKernel(InputType type
, ScalarAggregateFunction
* func
) {
198 AddAggKernel(KernelSignature::Make({type
}, ValueDescr::Scalar(int64())),
199 CountDistinctInit
<Type
, VisitorArgType
>, func
);
202 void AddCountDistinctKernels(ScalarAggregateFunction
* func
) {
204 AddCountDistinctKernel
<BooleanType
>(boolean(), func
);
206 AddCountDistinctKernel
<Int8Type
>(int8(), func
);
207 AddCountDistinctKernel
<Int16Type
>(int16(), func
);
208 AddCountDistinctKernel
<Int32Type
>(int32(), func
);
209 AddCountDistinctKernel
<Int64Type
>(int64(), func
);
210 AddCountDistinctKernel
<UInt8Type
>(uint8(), func
);
211 AddCountDistinctKernel
<UInt16Type
>(uint16(), func
);
212 AddCountDistinctKernel
<UInt32Type
>(uint32(), func
);
213 AddCountDistinctKernel
<UInt64Type
>(uint64(), func
);
214 AddCountDistinctKernel
<HalfFloatType
>(float16(), func
);
215 AddCountDistinctKernel
<FloatType
>(float32(), func
);
216 AddCountDistinctKernel
<DoubleType
>(float64(), func
);
218 AddCountDistinctKernel
<Date32Type
>(date32(), func
);
219 AddCountDistinctKernel
<Date64Type
>(date64(), func
);
221 AddCountDistinctKernel
<Time32Type
>(match::SameTypeId(Type::TIME32
), func
);
222 AddCountDistinctKernel
<Time64Type
>(match::SameTypeId(Type::TIME64
), func
);
223 // Timestamp & Duration
224 AddCountDistinctKernel
<TimestampType
>(match::SameTypeId(Type::TIMESTAMP
), func
);
225 AddCountDistinctKernel
<DurationType
>(match::SameTypeId(Type::DURATION
), func
);
227 AddCountDistinctKernel
<MonthIntervalType
>(month_interval(), func
);
228 AddCountDistinctKernel
<DayTimeIntervalType
>(day_time_interval(), func
);
229 AddCountDistinctKernel
<MonthDayNanoIntervalType
>(month_day_nano_interval(), func
);
231 AddCountDistinctKernel
<BinaryType
, util::string_view
>(match::BinaryLike(), func
);
232 AddCountDistinctKernel
<LargeBinaryType
, util::string_view
>(match::LargeBinaryLike(),
234 // Fixed binary & Decimal
235 AddCountDistinctKernel
<FixedSizeBinaryType
, util::string_view
>(
236 match::FixedSizeBinaryLike(), func
);
239 // ----------------------------------------------------------------------
240 // Sum implementation
242 template <typename ArrowType
>
243 struct SumImplDefault
: public SumImpl
<ArrowType
, SimdLevel::NONE
> {
244 using SumImpl
<ArrowType
, SimdLevel::NONE
>::SumImpl
;
247 template <typename ArrowType
>
248 struct MeanImplDefault
: public MeanImpl
<ArrowType
, SimdLevel::NONE
> {
249 using MeanImpl
<ArrowType
, SimdLevel::NONE
>::MeanImpl
;
252 Result
<std::unique_ptr
<KernelState
>> SumInit(KernelContext
* ctx
,
253 const KernelInitArgs
& args
) {
254 SumLikeInit
<SumImplDefault
> visitor(
255 ctx
, args
.inputs
[0].type
,
256 static_cast<const ScalarAggregateOptions
&>(*args
.options
));
257 return visitor
.Create();
260 Result
<std::unique_ptr
<KernelState
>> MeanInit(KernelContext
* ctx
,
261 const KernelInitArgs
& args
) {
262 SumLikeInit
<MeanImplDefault
> visitor(
263 ctx
, args
.inputs
[0].type
,
264 static_cast<const ScalarAggregateOptions
&>(*args
.options
));
265 return visitor
.Create();
268 // ----------------------------------------------------------------------
269 // Product implementation
271 using arrow::compute::internal::to_unsigned
;
273 template <typename ArrowType
>
274 struct ProductImpl
: public ScalarAggregator
{
275 using ThisType
= ProductImpl
<ArrowType
>;
276 using AccType
= typename FindAccumulatorType
<ArrowType
>::Type
;
277 using ProductType
= typename TypeTraits
<AccType
>::CType
;
278 using OutputType
= typename TypeTraits
<AccType
>::ScalarType
;
280 explicit ProductImpl(const std::shared_ptr
<DataType
>& out_type
,
281 const ScalarAggregateOptions
& options
)
282 : out_type(out_type
),
285 product(MultiplyTraits
<AccType
>::one(*out_type
)),
286 nulls_observed(false) {}
288 Status
Consume(KernelContext
*, const ExecBatch
& batch
) override
{
289 if (batch
[0].is_array()) {
290 const auto& data
= batch
[0].array();
291 this->count
+= data
->length
- data
->GetNullCount();
292 this->nulls_observed
= this->nulls_observed
|| data
->GetNullCount();
294 if (!options
.skip_nulls
&& this->nulls_observed
) {
299 internal::VisitArrayValuesInline
<ArrowType
>(
301 [&](typename TypeTraits
<ArrowType
>::CType value
) {
303 MultiplyTraits
<AccType
>::Multiply(*out_type
, this->product
, value
);
307 const auto& data
= *batch
[0].scalar();
308 this->count
+= data
.is_valid
* batch
.length
;
309 this->nulls_observed
= this->nulls_observed
|| !data
.is_valid
;
311 for (int64_t i
= 0; i
< batch
.length
; i
++) {
312 auto value
= internal::UnboxScalar
<ArrowType
>::Unbox(data
);
314 MultiplyTraits
<AccType
>::Multiply(*out_type
, this->product
, value
);
321 Status
MergeFrom(KernelContext
*, KernelState
&& src
) override
{
322 const auto& other
= checked_cast
<const ThisType
&>(src
);
323 this->count
+= other
.count
;
325 MultiplyTraits
<AccType
>::Multiply(*out_type
, this->product
, other
.product
);
326 this->nulls_observed
= this->nulls_observed
|| other
.nulls_observed
;
330 Status
Finalize(KernelContext
*, Datum
* out
) override
{
331 if ((!options
.skip_nulls
&& this->nulls_observed
) ||
332 (this->count
< options
.min_count
)) {
333 out
->value
= std::make_shared
<OutputType
>(out_type
);
335 out
->value
= std::make_shared
<OutputType
>(this->product
, out_type
);
340 std::shared_ptr
<DataType
> out_type
;
341 ScalarAggregateOptions options
;
348 std::unique_ptr
<KernelState
> state
;
350 const std::shared_ptr
<DataType
>& type
;
351 const ScalarAggregateOptions
& options
;
353 ProductInit(KernelContext
* ctx
, const std::shared_ptr
<DataType
>& type
,
354 const ScalarAggregateOptions
& options
)
355 : ctx(ctx
), type(type
), options(options
) {}
357 Status
Visit(const DataType
&) {
358 return Status::NotImplemented("No product implemented");
361 Status
Visit(const HalfFloatType
&) {
362 return Status::NotImplemented("No product implemented");
365 Status
Visit(const BooleanType
&) {
366 auto ty
= TypeTraits
<typename ProductImpl
<BooleanType
>::AccType
>::type_singleton();
367 state
.reset(new ProductImpl
<BooleanType
>(ty
, options
));
371 template <typename Type
>
372 enable_if_number
<Type
, Status
> Visit(const Type
&) {
373 auto ty
= TypeTraits
<typename ProductImpl
<Type
>::AccType
>::type_singleton();
374 state
.reset(new ProductImpl
<Type
>(ty
, options
));
378 template <typename Type
>
379 enable_if_decimal
<Type
, Status
> Visit(const Type
&) {
380 state
.reset(new ProductImpl
<Type
>(type
, options
));
384 Result
<std::unique_ptr
<KernelState
>> Create() {
385 RETURN_NOT_OK(VisitTypeInline(*type
, this));
386 return std::move(state
);
389 static Result
<std::unique_ptr
<KernelState
>> Init(KernelContext
* ctx
,
390 const KernelInitArgs
& args
) {
391 ProductInit
visitor(ctx
, args
.inputs
[0].type
,
392 static_cast<const ScalarAggregateOptions
&>(*args
.options
));
393 return visitor
.Create();
397 // ----------------------------------------------------------------------
398 // MinMax implementation
400 Result
<std::unique_ptr
<KernelState
>> MinMaxInit(KernelContext
* ctx
,
401 const KernelInitArgs
& args
) {
402 ARROW_ASSIGN_OR_RAISE(auto out_type
,
403 args
.kernel
->signature
->out_type().Resolve(ctx
, args
.inputs
));
404 MinMaxInitState
<SimdLevel::NONE
> visitor(
405 ctx
, *args
.inputs
[0].type
, std::move(out_type
.type
),
406 static_cast<const ScalarAggregateOptions
&>(*args
.options
));
407 return visitor
.Create();
410 // For "min" and "max" functions: override finalize and return the actual value
411 template <MinOrMax min_or_max
>
412 void AddMinOrMaxAggKernel(ScalarAggregateFunction
* func
,
413 ScalarAggregateFunction
* min_max_func
) {
414 auto sig
= KernelSignature::Make(
415 {InputType(ValueDescr::ANY
)},
416 OutputType([](KernelContext
*,
417 const std::vector
<ValueDescr
>& descrs
) -> Result
<ValueDescr
> {
418 // any[T] -> scalar[T]
419 return ValueDescr::Scalar(descrs
.front().type
);
422 auto init
= [min_max_func
](
424 const KernelInitArgs
& args
) -> Result
<std::unique_ptr
<KernelState
>> {
425 std::vector
<ValueDescr
> inputs
= args
.inputs
;
426 ARROW_ASSIGN_OR_RAISE(auto kernel
, min_max_func
->DispatchBest(&inputs
));
427 KernelInitArgs new_args
{kernel
, inputs
, args
.options
};
428 return kernel
->init(ctx
, new_args
);
431 auto finalize
= [](KernelContext
* ctx
, Datum
* out
) -> Status
{
433 RETURN_NOT_OK(checked_cast
<ScalarAggregator
*>(ctx
->state())->Finalize(ctx
, &temp
));
434 const auto& result
= temp
.scalar_as
<StructScalar
>();
435 DCHECK(result
.is_valid
);
436 *out
= result
.value
[static_cast<uint8_t>(min_or_max
)];
440 // Note SIMD level is always NONE, but the convenience kernel will
441 // dispatch to an appropriate implementation
442 AddAggKernel(std::move(sig
), std::move(init
), std::move(finalize
), func
);
445 // ----------------------------------------------------------------------
446 // Any implementation
448 struct BooleanAnyImpl
: public ScalarAggregator
{
449 explicit BooleanAnyImpl(ScalarAggregateOptions options
) : options(std::move(options
)) {}
451 Status
Consume(KernelContext
*, const ExecBatch
& batch
) override
{
452 // short-circuit if seen a True already
453 if (this->any
== true && this->count
>= options
.min_count
) {
456 if (batch
[0].is_scalar()) {
457 const auto& scalar
= *batch
[0].scalar();
458 this->has_nulls
= !scalar
.is_valid
;
459 this->any
= scalar
.is_valid
&& checked_cast
<const BooleanScalar
&>(scalar
).value
;
460 this->count
+= scalar
.is_valid
;
463 const auto& data
= *batch
[0].array();
464 this->has_nulls
= data
.GetNullCount() > 0;
465 this->count
+= data
.length
- data
.GetNullCount();
466 arrow::internal::OptionalBinaryBitBlockCounter
counter(
467 data
.buffers
[0], data
.offset
, data
.buffers
[1], data
.offset
, data
.length
);
468 int64_t position
= 0;
469 while (position
< data
.length
) {
470 const auto block
= counter
.NextAndBlock();
471 if (block
.popcount
> 0) {
475 position
+= block
.length
;
480 Status
MergeFrom(KernelContext
*, KernelState
&& src
) override
{
481 const auto& other
= checked_cast
<const BooleanAnyImpl
&>(src
);
482 this->any
|= other
.any
;
483 this->has_nulls
|= other
.has_nulls
;
484 this->count
+= other
.count
;
488 Status
Finalize(KernelContext
* ctx
, Datum
* out
) override
{
489 if ((!options
.skip_nulls
&& !this->any
&& this->has_nulls
) ||
490 this->count
< options
.min_count
) {
491 out
->value
= std::make_shared
<BooleanScalar
>();
493 out
->value
= std::make_shared
<BooleanScalar
>(this->any
);
499 bool has_nulls
= false;
501 ScalarAggregateOptions options
;
504 Result
<std::unique_ptr
<KernelState
>> AnyInit(KernelContext
*, const KernelInitArgs
& args
) {
505 const ScalarAggregateOptions options
=
506 static_cast<const ScalarAggregateOptions
&>(*args
.options
);
507 return ::arrow::internal::make_unique
<BooleanAnyImpl
>(
508 static_cast<const ScalarAggregateOptions
&>(*args
.options
));
511 // ----------------------------------------------------------------------
512 // All implementation
514 struct BooleanAllImpl
: public ScalarAggregator
{
515 explicit BooleanAllImpl(ScalarAggregateOptions options
) : options(std::move(options
)) {}
517 Status
Consume(KernelContext
*, const ExecBatch
& batch
) override
{
518 // short-circuit if seen a false already
519 if (this->all
== false && this->count
>= options
.min_count
) {
522 // short-circuit if seen a null already
523 if (!options
.skip_nulls
&& this->has_nulls
) {
526 if (batch
[0].is_scalar()) {
527 const auto& scalar
= *batch
[0].scalar();
528 this->has_nulls
= !scalar
.is_valid
;
529 this->count
+= scalar
.is_valid
;
530 this->all
= !scalar
.is_valid
|| checked_cast
<const BooleanScalar
&>(scalar
).value
;
533 const auto& data
= *batch
[0].array();
534 this->has_nulls
= data
.GetNullCount() > 0;
535 this->count
+= data
.length
- data
.GetNullCount();
536 arrow::internal::OptionalBinaryBitBlockCounter
counter(
537 data
.buffers
[1], data
.offset
, data
.buffers
[0], data
.offset
, data
.length
);
538 int64_t position
= 0;
539 while (position
< data
.length
) {
540 const auto block
= counter
.NextOrNotBlock();
541 if (!block
.AllSet()) {
545 position
+= block
.length
;
551 Status
MergeFrom(KernelContext
*, KernelState
&& src
) override
{
552 const auto& other
= checked_cast
<const BooleanAllImpl
&>(src
);
553 this->all
&= other
.all
;
554 this->has_nulls
|= other
.has_nulls
;
555 this->count
+= other
.count
;
559 Status
Finalize(KernelContext
*, Datum
* out
) override
{
560 if ((!options
.skip_nulls
&& this->all
&& this->has_nulls
) ||
561 this->count
< options
.min_count
) {
562 out
->value
= std::make_shared
<BooleanScalar
>();
564 out
->value
= std::make_shared
<BooleanScalar
>(this->all
);
570 bool has_nulls
= false;
572 ScalarAggregateOptions options
;
575 Result
<std::unique_ptr
<KernelState
>> AllInit(KernelContext
*, const KernelInitArgs
& args
) {
576 return ::arrow::internal::make_unique
<BooleanAllImpl
>(
577 static_cast<const ScalarAggregateOptions
&>(*args
.options
));
580 // ----------------------------------------------------------------------
581 // Index implementation
583 template <typename ArgType
>
584 struct IndexImpl
: public ScalarAggregator
{
585 using ArgValue
= typename
internal::GetViewType
<ArgType
>::T
;
587 explicit IndexImpl(IndexOptions options
, KernelState
* raw_state
)
588 : options(std::move(options
)), seen(0), index(-1) {
589 if (auto state
= static_cast<IndexImpl
<ArgType
>*>(raw_state
)) {
591 index
= state
->index
;
595 Status
Consume(KernelContext
* ctx
, const ExecBatch
& batch
) override
{
597 if (index
>= 0 || !options
.value
->is_valid
) {
601 const ArgValue desired
= internal::UnboxScalar
<ArgType
>::Unbox(*options
.value
);
603 if (batch
[0].is_scalar()) {
605 if (batch
[0].scalar()->is_valid
) {
606 const ArgValue v
= internal::UnboxScalar
<ArgType
>::Unbox(*batch
[0].scalar());
609 return Status::Cancelled("Found");
615 auto input
= batch
[0].array();
616 seen
= input
->length
;
619 ARROW_UNUSED(internal::VisitArrayValuesInline
<ArgType
>(
621 [&](ArgValue v
) -> Status
{
624 return Status::Cancelled("Found");
638 Status
MergeFrom(KernelContext
*, KernelState
&& src
) override
{
639 const auto& other
= checked_cast
<const IndexImpl
&>(src
);
640 if (index
< 0 && other
.index
>= 0) {
641 index
= seen
+ other
.index
;
647 Status
Finalize(KernelContext
*, Datum
* out
) override
{
648 out
->value
= std::make_shared
<Int64Scalar
>(index
>= 0 ? index
: -1);
652 const IndexOptions options
;
658 std::unique_ptr
<KernelState
> state
;
660 const IndexOptions
& options
;
661 const DataType
& type
;
663 IndexInit(KernelContext
* ctx
, const IndexOptions
& options
, const DataType
& type
)
664 : ctx(ctx
), options(options
), type(type
) {}
666 Status
Visit(const DataType
& type
) {
667 return Status::NotImplemented("Index kernel not implemented for ", type
.ToString());
670 Status
Visit(const BooleanType
&) {
671 state
.reset(new IndexImpl
<BooleanType
>(options
, ctx
->state()));
675 template <typename Type
>
676 enable_if_number
<Type
, Status
> Visit(const Type
&) {
677 state
.reset(new IndexImpl
<Type
>(options
, ctx
->state()));
681 template <typename Type
>
682 enable_if_base_binary
<Type
, Status
> Visit(const Type
&) {
683 state
.reset(new IndexImpl
<Type
>(options
, ctx
->state()));
687 template <typename Type
>
688 enable_if_date
<Type
, Status
> Visit(const Type
&) {
689 state
.reset(new IndexImpl
<Type
>(options
, ctx
->state()));
693 template <typename Type
>
694 enable_if_time
<Type
, Status
> Visit(const Type
&) {
695 state
.reset(new IndexImpl
<Type
>(options
, ctx
->state()));
699 template <typename Type
>
700 enable_if_timestamp
<Type
, Status
> Visit(const Type
&) {
701 state
.reset(new IndexImpl
<Type
>(options
, ctx
->state()));
705 Result
<std::unique_ptr
<KernelState
>> Create() {
706 RETURN_NOT_OK(VisitTypeInline(type
, this));
707 return std::move(state
);
710 static Result
<std::unique_ptr
<KernelState
>> Init(KernelContext
* ctx
,
711 const KernelInitArgs
& args
) {
713 return Status::Invalid("Must provide IndexOptions for index kernel");
715 IndexInit
visitor(ctx
, static_cast<const IndexOptions
&>(*args
.options
),
716 *args
.inputs
[0].type
);
717 return visitor
.Create();
723 void AddBasicAggKernels(KernelInit init
,
724 const std::vector
<std::shared_ptr
<DataType
>>& types
,
725 std::shared_ptr
<DataType
> out_ty
, ScalarAggregateFunction
* func
,
726 SimdLevel::type simd_level
) {
727 for (const auto& ty
: types
) {
728 // array[InT] -> scalar[OutT]
730 KernelSignature::Make({InputType::Array(ty
->id())}, ValueDescr::Scalar(out_ty
));
731 AddAggKernel(std::move(sig
), init
, func
, simd_level
);
735 void AddScalarAggKernels(KernelInit init
,
736 const std::vector
<std::shared_ptr
<DataType
>>& types
,
737 std::shared_ptr
<DataType
> out_ty
,
738 ScalarAggregateFunction
* func
) {
739 for (const auto& ty
: types
) {
740 // scalar[InT] -> scalar[OutT]
742 KernelSignature::Make({InputType::Scalar(ty
->id())}, ValueDescr::Scalar(out_ty
));
743 AddAggKernel(std::move(sig
), init
, func
, SimdLevel::NONE
);
747 void AddArrayScalarAggKernels(KernelInit init
,
748 const std::vector
<std::shared_ptr
<DataType
>>& types
,
749 std::shared_ptr
<DataType
> out_ty
,
750 ScalarAggregateFunction
* func
,
751 SimdLevel::type simd_level
= SimdLevel::NONE
) {
752 AddBasicAggKernels(init
, types
, out_ty
, func
, simd_level
);
753 AddScalarAggKernels(init
, types
, out_ty
, func
);
758 Result
<ValueDescr
> MinMaxType(KernelContext
*, const std::vector
<ValueDescr
>& descrs
) {
759 // any[T] -> scalar[struct<min: T, max: T>]
760 auto ty
= descrs
.front().type
;
761 return ValueDescr::Scalar(struct_({field("min", ty
), field("max", ty
)}));
766 void AddMinMaxKernel(KernelInit init
, internal::detail::GetTypeId get_id
,
767 ScalarAggregateFunction
* func
, SimdLevel::type simd_level
) {
768 auto sig
= KernelSignature::Make({InputType(get_id
.id
)}, OutputType(MinMaxType
));
769 AddAggKernel(std::move(sig
), init
, func
, simd_level
);
772 void AddMinMaxKernels(KernelInit init
,
773 const std::vector
<std::shared_ptr
<DataType
>>& types
,
774 ScalarAggregateFunction
* func
, SimdLevel::type simd_level
) {
775 for (const auto& ty
: types
) {
776 AddMinMaxKernel(init
, ty
, func
, simd_level
);
782 Result
<ValueDescr
> ScalarFirstType(KernelContext
*,
783 const std::vector
<ValueDescr
>& descrs
) {
784 ValueDescr result
= descrs
.front();
785 result
.shape
= ValueDescr::SCALAR
;
// Documentation records for the aggregate functions in this file. Each one is
// passed by address to the ScalarAggregateFunction of the matching name in
// RegisterScalarAggregateBasic.

// Documentation for "count".
789 const FunctionDoc count_doc
{"Count the number of null / non-null values",
790 ("By default, only non-null values are counted.\n"
791 "This can be changed through CountOptions."),
// Documentation for "count_distinct".
795 const FunctionDoc count_distinct_doc
{"Count the number of unique values",
796 ("By default, only non-null values are counted.\n"
797 "This can be changed through CountOptions."),
// Documentation for "sum".
801 const FunctionDoc sum_doc
{
802 "Compute the sum of a numeric array",
803 ("Null values are ignored by default. Minimum count of non-null\n"
804 "values can be set and null is returned if too few are present.\n"
805 "This can be changed through ScalarAggregateOptions."),
807 "ScalarAggregateOptions"};
// Documentation for "product".
809 const FunctionDoc product_doc
{
810 "Compute the product of values in a numeric array",
811 ("Null values are ignored by default. Minimum count of non-null\n"
812 "values can be set and null is returned if too few are present.\n"
813 "This can be changed through ScalarAggregateOptions."),
815 "ScalarAggregateOptions"};
// Documentation for "mean".
817 const FunctionDoc mean_doc
{
818 "Compute the mean of a numeric array",
819 ("Null values are ignored by default. Minimum count of non-null\n"
820 "values can be set and null is returned if too few are present.\n"
821 "This can be changed through ScalarAggregateOptions.\n"
822 "The result is a double for integer and floating point arguments,\n"
823 "and a decimal with the same bit-width/precision/scale for decimal arguments.\n"
824 "For integers and floats, NaN is returned if min_count = 0 and\n"
825 "there are no values. For decimals, null is returned instead."),
827 "ScalarAggregateOptions"};
// Documentation for "min_max".
829 const FunctionDoc min_max_doc
{"Compute the minimum and maximum values of a numeric array",
830 ("Null values are ignored by default.\n"
831 "This can be changed through ScalarAggregateOptions."),
833 "ScalarAggregateOptions"};
// Documentation shared by the "min" and "max" convenience functions.
835 const FunctionDoc min_or_max_doc
{
836 "Compute the minimum or maximum values of a numeric array",
837 ("Null values are ignored by default.\n"
838 "This can be changed through ScalarAggregateOptions."),
840 "ScalarAggregateOptions"};
// Documentation for "any".
842 const FunctionDoc any_doc
{"Test whether any element in a boolean array evaluates to true",
843 ("Null values are ignored by default.\n"
844 "If null values are taken into account by setting "
845 "ScalarAggregateOptions parameter skip_nulls = false then "
846 "Kleene logic is used.\n"
847 "See KleeneOr for more details on Kleene logic."),
849 "ScalarAggregateOptions"};
// Documentation for "all".
851 const FunctionDoc all_doc
{"Test whether all elements in a boolean array evaluate to true",
852 ("Null values are ignored by default.\n"
853 "If null values are taken into account by setting "
854 "ScalarAggregateOptions parameter skip_nulls = false then "
855 "Kleene logic is used.\n"
856 "See KleeneAnd for more details on Kleene logic."),
858 "ScalarAggregateOptions"};
// Documentation for "index".
860 const FunctionDoc index_doc
{"Find the index of the first occurrence of a given value",
861 ("The result is always computed as an int64_t, regardless\n"
862 "of the offset type of the input array."),
868 void RegisterScalarAggregateBasic(FunctionRegistry
* registry
) {
869 static auto default_scalar_aggregate_options
= ScalarAggregateOptions::Defaults();
870 static auto default_count_options
= CountOptions::Defaults();
872 auto func
= std::make_shared
<ScalarAggregateFunction
>(
873 "count", Arity::Unary(), &count_doc
, &default_count_options
);
875 // Takes any input, outputs int64 scalar
877 AddAggKernel(KernelSignature::Make({any_input
}, ValueDescr::Scalar(int64())), CountInit
,
879 DCHECK_OK(registry
->AddFunction(std::move(func
)));
881 func
= std::make_shared
<ScalarAggregateFunction
>(
882 "count_distinct", Arity::Unary(), &count_distinct_doc
, &default_count_options
);
883 // Takes any input, outputs int64 scalar
884 AddCountDistinctKernels(func
.get());
885 DCHECK_OK(registry
->AddFunction(std::move(func
)));
887 func
= std::make_shared
<ScalarAggregateFunction
>("sum", Arity::Unary(), &sum_doc
,
888 &default_scalar_aggregate_options
);
889 AddArrayScalarAggKernels(SumInit
, {boolean()}, uint64(), func
.get());
891 KernelSignature::Make({InputType(Type::DECIMAL128
)}, OutputType(ScalarFirstType
)),
892 SumInit
, func
.get(), SimdLevel::NONE
);
894 KernelSignature::Make({InputType(Type::DECIMAL256
)}, OutputType(ScalarFirstType
)),
895 SumInit
, func
.get(), SimdLevel::NONE
);
896 AddArrayScalarAggKernels(SumInit
, SignedIntTypes(), int64(), func
.get());
897 AddArrayScalarAggKernels(SumInit
, UnsignedIntTypes(), uint64(), func
.get());
898 AddArrayScalarAggKernels(SumInit
, FloatingPointTypes(), float64(), func
.get());
899 // Add the SIMD variants for sum
900 #if defined(ARROW_HAVE_RUNTIME_AVX2) || defined(ARROW_HAVE_RUNTIME_AVX512)
901 auto cpu_info
= arrow::internal::CpuInfo::GetInstance();
903 #if defined(ARROW_HAVE_RUNTIME_AVX2)
904 if (cpu_info
->IsSupported(arrow::internal::CpuInfo::AVX2
)) {
905 AddSumAvx2AggKernels(func
.get());
908 #if defined(ARROW_HAVE_RUNTIME_AVX512)
909 if (cpu_info
->IsSupported(arrow::internal::CpuInfo::AVX512
)) {
910 AddSumAvx512AggKernels(func
.get());
913 DCHECK_OK(registry
->AddFunction(std::move(func
)));
915 func
= std::make_shared
<ScalarAggregateFunction
>("mean", Arity::Unary(), &mean_doc
,
916 &default_scalar_aggregate_options
);
917 AddArrayScalarAggKernels(MeanInit
, {boolean()}, float64(), func
.get());
918 AddArrayScalarAggKernels(MeanInit
, NumericTypes(), float64(), func
.get());
920 KernelSignature::Make({InputType(Type::DECIMAL128
)}, OutputType(ScalarFirstType
)),
921 MeanInit
, func
.get(), SimdLevel::NONE
);
923 KernelSignature::Make({InputType(Type::DECIMAL256
)}, OutputType(ScalarFirstType
)),
924 MeanInit
, func
.get(), SimdLevel::NONE
);
925 // Add the SIMD variants for mean
926 #if defined(ARROW_HAVE_RUNTIME_AVX2)
927 if (cpu_info
->IsSupported(arrow::internal::CpuInfo::AVX2
)) {
928 AddMeanAvx2AggKernels(func
.get());
931 #if defined(ARROW_HAVE_RUNTIME_AVX512)
932 if (cpu_info
->IsSupported(arrow::internal::CpuInfo::AVX512
)) {
933 AddMeanAvx512AggKernels(func
.get());
936 DCHECK_OK(registry
->AddFunction(std::move(func
)));
938 func
= std::make_shared
<ScalarAggregateFunction
>(
939 "min_max", Arity::Unary(), &min_max_doc
, &default_scalar_aggregate_options
);
940 AddMinMaxKernels(MinMaxInit
, {null(), boolean()}, func
.get());
941 AddMinMaxKernels(MinMaxInit
, NumericTypes(), func
.get());
942 AddMinMaxKernels(MinMaxInit
, TemporalTypes(), func
.get());
943 AddMinMaxKernels(MinMaxInit
, BaseBinaryTypes(), func
.get());
944 AddMinMaxKernel(MinMaxInit
, Type::FIXED_SIZE_BINARY
, func
.get());
945 AddMinMaxKernel(MinMaxInit
, Type::INTERVAL_MONTHS
, func
.get());
946 AddMinMaxKernel(MinMaxInit
, Type::DECIMAL128
, func
.get());
947 AddMinMaxKernel(MinMaxInit
, Type::DECIMAL256
, func
.get());
948 // Add the SIMD variants for min max
949 #if defined(ARROW_HAVE_RUNTIME_AVX2)
950 if (cpu_info
->IsSupported(arrow::internal::CpuInfo::AVX2
)) {
951 AddMinMaxAvx2AggKernels(func
.get());
954 #if defined(ARROW_HAVE_RUNTIME_AVX512)
955 if (cpu_info
->IsSupported(arrow::internal::CpuInfo::AVX512
)) {
956 AddMinMaxAvx512AggKernels(func
.get());
960 auto min_max_func
= func
.get();
961 DCHECK_OK(registry
->AddFunction(std::move(func
)));
963 // Add min/max as convenience functions
964 func
= std::make_shared
<ScalarAggregateFunction
>("min", Arity::Unary(), &min_or_max_doc
,
965 &default_scalar_aggregate_options
);
966 AddMinOrMaxAggKernel
<MinOrMax::Min
>(func
.get(), min_max_func
);
967 DCHECK_OK(registry
->AddFunction(std::move(func
)));
969 func
= std::make_shared
<ScalarAggregateFunction
>("max", Arity::Unary(), &min_or_max_doc
,
970 &default_scalar_aggregate_options
);
971 AddMinOrMaxAggKernel
<MinOrMax::Max
>(func
.get(), min_max_func
);
972 DCHECK_OK(registry
->AddFunction(std::move(func
)));
974 func
= std::make_shared
<ScalarAggregateFunction
>(
975 "product", Arity::Unary(), &product_doc
, &default_scalar_aggregate_options
);
976 AddArrayScalarAggKernels(ProductInit::Init
, {boolean()}, uint64(), func
.get());
977 AddArrayScalarAggKernels(ProductInit::Init
, SignedIntTypes(), int64(), func
.get());
978 AddArrayScalarAggKernels(ProductInit::Init
, UnsignedIntTypes(), uint64(), func
.get());
979 AddArrayScalarAggKernels(ProductInit::Init
, FloatingPointTypes(), float64(),
982 KernelSignature::Make({InputType(Type::DECIMAL128
)}, OutputType(ScalarFirstType
)),
983 ProductInit::Init
, func
.get(), SimdLevel::NONE
);
985 KernelSignature::Make({InputType(Type::DECIMAL256
)}, OutputType(ScalarFirstType
)),
986 ProductInit::Init
, func
.get(), SimdLevel::NONE
);
987 DCHECK_OK(registry
->AddFunction(std::move(func
)));
990 func
= std::make_shared
<ScalarAggregateFunction
>("any", Arity::Unary(), &any_doc
,
991 &default_scalar_aggregate_options
);
992 AddArrayScalarAggKernels(AnyInit
, {boolean()}, boolean(), func
.get());
993 DCHECK_OK(registry
->AddFunction(std::move(func
)));
996 func
= std::make_shared
<ScalarAggregateFunction
>("all", Arity::Unary(), &all_doc
,
997 &default_scalar_aggregate_options
);
998 AddArrayScalarAggKernels(AllInit
, {boolean()}, boolean(), func
.get());
999 DCHECK_OK(registry
->AddFunction(std::move(func
)));
1002 func
= std::make_shared
<ScalarAggregateFunction
>("index", Arity::Unary(), &index_doc
);
1003 AddBasicAggKernels(IndexInit::Init
, BaseBinaryTypes(), int64(), func
.get());
1004 AddBasicAggKernels(IndexInit::Init
, PrimitiveTypes(), int64(), func
.get());
1005 AddBasicAggKernels(IndexInit::Init
, TemporalTypes(), int64(), func
.get());
1006 DCHECK_OK(registry
->AddFunction(std::move(func
)));
1009 } // namespace internal
1010 } // namespace compute
1011 } // namespace arrow