]> git.proxmox.com Git - ceph.git/blob - ceph/src/arrow/cpp/src/arrow/compute/kernels/aggregate_basic.cc
import quincy 17.2.0
[ceph.git] / ceph / src / arrow / cpp / src / arrow / compute / kernels / aggregate_basic.cc
1 // Licensed to the Apache Software Foundation (ASF) under one
2 // or more contributor license agreements. See the NOTICE file
3 // distributed with this work for additional information
4 // regarding copyright ownership. The ASF licenses this file
5 // to you under the Apache License, Version 2.0 (the
6 // "License"); you may not use this file except in compliance
7 // with the License. You may obtain a copy of the License at
8 //
9 // http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing,
12 // software distributed under the License is distributed on an
13 // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14 // KIND, either express or implied. See the License for the
15 // specific language governing permissions and limitations
16 // under the License.
17
18 #include "arrow/compute/api_aggregate.h"
19 #include "arrow/compute/kernels/aggregate_basic_internal.h"
20 #include "arrow/compute/kernels/aggregate_internal.h"
21 #include "arrow/compute/kernels/common.h"
22 #include "arrow/compute/kernels/util_internal.h"
23 #include "arrow/util/cpu_info.h"
24 #include "arrow/util/hashing.h"
25 #include "arrow/util/make_unique.h"
26
27 namespace arrow {
28 namespace compute {
29 namespace internal {
30
31 namespace {
32
33 Status AggregateConsume(KernelContext* ctx, const ExecBatch& batch) {
34 return checked_cast<ScalarAggregator*>(ctx->state())->Consume(ctx, batch);
35 }
36
37 Status AggregateMerge(KernelContext* ctx, KernelState&& src, KernelState* dst) {
38 return checked_cast<ScalarAggregator*>(dst)->MergeFrom(ctx, std::move(src));
39 }
40
41 Status AggregateFinalize(KernelContext* ctx, Datum* out) {
42 return checked_cast<ScalarAggregator*>(ctx->state())->Finalize(ctx, out);
43 }
44
45 } // namespace
46
47 void AddAggKernel(std::shared_ptr<KernelSignature> sig, KernelInit init,
48 ScalarAggregateFunction* func, SimdLevel::type simd_level) {
49 ScalarAggregateKernel kernel(std::move(sig), std::move(init), AggregateConsume,
50 AggregateMerge, AggregateFinalize);
51 // Set the simd level
52 kernel.simd_level = simd_level;
53 DCHECK_OK(func->AddKernel(std::move(kernel)));
54 }
55
56 void AddAggKernel(std::shared_ptr<KernelSignature> sig, KernelInit init,
57 ScalarAggregateFinalize finalize, ScalarAggregateFunction* func,
58 SimdLevel::type simd_level) {
59 ScalarAggregateKernel kernel(std::move(sig), std::move(init), AggregateConsume,
60 AggregateMerge, std::move(finalize));
61 // Set the simd level
62 kernel.simd_level = simd_level;
63 DCHECK_OK(func->AddKernel(std::move(kernel)));
64 }
65
66 namespace {
67
68 // ----------------------------------------------------------------------
69 // Count implementation
70
71 struct CountImpl : public ScalarAggregator {
72 explicit CountImpl(CountOptions options) : options(std::move(options)) {}
73
74 Status Consume(KernelContext*, const ExecBatch& batch) override {
75 if (options.mode == CountOptions::ALL) {
76 this->non_nulls += batch.length;
77 } else if (batch[0].is_array()) {
78 const ArrayData& input = *batch[0].array();
79 const int64_t nulls = input.GetNullCount();
80 this->nulls += nulls;
81 this->non_nulls += input.length - nulls;
82 } else {
83 const Scalar& input = *batch[0].scalar();
84 this->nulls += !input.is_valid * batch.length;
85 this->non_nulls += input.is_valid * batch.length;
86 }
87 return Status::OK();
88 }
89
90 Status MergeFrom(KernelContext*, KernelState&& src) override {
91 const auto& other_state = checked_cast<const CountImpl&>(src);
92 this->non_nulls += other_state.non_nulls;
93 this->nulls += other_state.nulls;
94 return Status::OK();
95 }
96
97 Status Finalize(KernelContext* ctx, Datum* out) override {
98 const auto& state = checked_cast<const CountImpl&>(*ctx->state());
99 switch (state.options.mode) {
100 case CountOptions::ONLY_VALID:
101 case CountOptions::ALL:
102 // ALL is equivalent since we don't count the null/non-null
103 // separately to avoid potentially computing null count
104 *out = Datum(state.non_nulls);
105 break;
106 case CountOptions::ONLY_NULL:
107 *out = Datum(state.nulls);
108 break;
109 default:
110 DCHECK(false) << "unreachable";
111 }
112 return Status::OK();
113 }
114
115 CountOptions options;
116 int64_t non_nulls = 0;
117 int64_t nulls = 0;
118 };
119
120 Result<std::unique_ptr<KernelState>> CountInit(KernelContext*,
121 const KernelInitArgs& args) {
122 return ::arrow::internal::make_unique<CountImpl>(
123 static_cast<const CountOptions&>(*args.options));
124 }
125
126 // ----------------------------------------------------------------------
127 // Distinct Count implementation
128
// Aggregator counting distinct values via a memo table (hash table) keyed
// on the physical value type. `VisitorArgType` is the C type the array
// visitor yields (e.g. util::string_view for binary-like types).
template <typename Type, typename VisitorArgType>
struct CountDistinctImpl : public ScalarAggregator {
  using MemoTable = typename arrow::internal::HashTraits<Type>::MemoTableType;

  explicit CountDistinctImpl(MemoryPool* memory_pool, CountOptions options)
      : options(std::move(options)), memo_table_(new MemoTable(memory_pool, 0)) {}

  Status Consume(KernelContext*, const ExecBatch& batch) override {
    if (batch[0].is_array()) {
      const ArrayData& arr = *batch[0].array();
      // Nulls are tracked separately (see has_nulls); only values enter the
      // memo table.
      auto visit_null = []() { return Status::OK(); };
      auto visit_value = [&](VisitorArgType arg) {
        int y;  // memo index; only insertion (dedup) matters here
        return memo_table_->GetOrInsert(arg, &y);
      };
      RETURN_NOT_OK(VisitArrayDataInline<Type>(arr, visit_value, visit_null));
      // NOTE(review): memo_table_->size() is cumulative across batches, so
      // adding it on every Consume over-counts when a state consumes more
      // than one batch; likewise has_nulls is assigned (not OR-ed), dropping
      // nulls seen in earlier batches. Verify against upstream Arrow fixes.
      this->non_nulls += memo_table_->size();
      this->has_nulls = arr.GetNullCount() > 0;
    } else {
      const Scalar& input = *batch[0].scalar();
      this->has_nulls = !input.is_valid;
      if (input.is_valid) {
        // A scalar is `batch.length` copies of one value.
        this->non_nulls += batch.length;
      }
    }
    return Status::OK();
  }

  Status MergeFrom(KernelContext*, KernelState&& src) override {
    // NOTE(review): memo tables are not merged here, so values occurring in
    // several states are counted once per state — confirm intended semantics.
    const auto& other_state = checked_cast<const CountDistinctImpl&>(src);
    this->non_nulls += other_state.non_nulls;
    this->has_nulls = this->has_nulls || other_state.has_nulls;
    return Status::OK();
  }

  Status Finalize(KernelContext* ctx, Datum* out) override {
    const auto& state = checked_cast<const CountDistinctImpl&>(*ctx->state());
    // All nulls collapse into at most one distinct "value".
    const int64_t nulls = state.has_nulls ? 1 : 0;
    switch (state.options.mode) {
      case CountOptions::ONLY_VALID:
        *out = Datum(state.non_nulls);
        break;
      case CountOptions::ALL:
        *out = Datum(state.non_nulls + nulls);
        break;
      case CountOptions::ONLY_NULL:
        *out = Datum(nulls);
        break;
      default:
        DCHECK(false) << "unreachable";
    }
    return Status::OK();
  }

  const CountOptions options;
  int64_t non_nulls = 0;   // number of distinct non-null values observed
  bool has_nulls = false;  // whether any null was observed
  std::unique_ptr<MemoTable> memo_table_;  // deduplication table
};
188
189 template <typename Type, typename VisitorArgType>
190 Result<std::unique_ptr<KernelState>> CountDistinctInit(KernelContext* ctx,
191 const KernelInitArgs& args) {
192 return ::arrow::internal::make_unique<CountDistinctImpl<Type, VisitorArgType>>(
193 ctx->memory_pool(), static_cast<const CountOptions&>(*args.options));
194 }
195
196 template <typename Type, typename VisitorArgType = typename Type::c_type>
197 void AddCountDistinctKernel(InputType type, ScalarAggregateFunction* func) {
198 AddAggKernel(KernelSignature::Make({type}, ValueDescr::Scalar(int64())),
199 CountDistinctInit<Type, VisitorArgType>, func);
200 }
201
// Register count_distinct kernels for every supported physical type.
// Binary-like and fixed-size-binary types are memoized via string_view.
void AddCountDistinctKernels(ScalarAggregateFunction* func) {
  // Boolean
  AddCountDistinctKernel<BooleanType>(boolean(), func);
  // Number
  AddCountDistinctKernel<Int8Type>(int8(), func);
  AddCountDistinctKernel<Int16Type>(int16(), func);
  AddCountDistinctKernel<Int32Type>(int32(), func);
  AddCountDistinctKernel<Int64Type>(int64(), func);
  AddCountDistinctKernel<UInt8Type>(uint8(), func);
  AddCountDistinctKernel<UInt16Type>(uint16(), func);
  AddCountDistinctKernel<UInt32Type>(uint32(), func);
  AddCountDistinctKernel<UInt64Type>(uint64(), func);
  AddCountDistinctKernel<HalfFloatType>(float16(), func);
  AddCountDistinctKernel<FloatType>(float32(), func);
  AddCountDistinctKernel<DoubleType>(float64(), func);
  // Date
  AddCountDistinctKernel<Date32Type>(date32(), func);
  AddCountDistinctKernel<Date64Type>(date64(), func);
  // Time
  AddCountDistinctKernel<Time32Type>(match::SameTypeId(Type::TIME32), func);
  AddCountDistinctKernel<Time64Type>(match::SameTypeId(Type::TIME64), func);
  // Timestamp & Duration
  AddCountDistinctKernel<TimestampType>(match::SameTypeId(Type::TIMESTAMP), func);
  AddCountDistinctKernel<DurationType>(match::SameTypeId(Type::DURATION), func);
  // Interval
  AddCountDistinctKernel<MonthIntervalType>(month_interval(), func);
  AddCountDistinctKernel<DayTimeIntervalType>(day_time_interval(), func);
  AddCountDistinctKernel<MonthDayNanoIntervalType>(month_day_nano_interval(), func);
  // Binary & String
  AddCountDistinctKernel<BinaryType, util::string_view>(match::BinaryLike(), func);
  AddCountDistinctKernel<LargeBinaryType, util::string_view>(match::LargeBinaryLike(),
                                                             func);
  // Fixed binary & Decimal
  AddCountDistinctKernel<FixedSizeBinaryType, util::string_view>(
      match::FixedSizeBinaryLike(), func);
}
238
239 // ----------------------------------------------------------------------
240 // Sum implementation
241
// Non-SIMD ("default") instantiation of the generic SumImpl; inherits all
// behavior, only pinning SimdLevel::NONE.
template <typename ArrowType>
struct SumImplDefault : public SumImpl<ArrowType, SimdLevel::NONE> {
  using SumImpl<ArrowType, SimdLevel::NONE>::SumImpl;
};
246
// Non-SIMD ("default") instantiation of the generic MeanImpl; inherits all
// behavior, only pinning SimdLevel::NONE.
template <typename ArrowType>
struct MeanImplDefault : public MeanImpl<ArrowType, SimdLevel::NONE> {
  using MeanImpl<ArrowType, SimdLevel::NONE>::MeanImpl;
};
251
252 Result<std::unique_ptr<KernelState>> SumInit(KernelContext* ctx,
253 const KernelInitArgs& args) {
254 SumLikeInit<SumImplDefault> visitor(
255 ctx, args.inputs[0].type,
256 static_cast<const ScalarAggregateOptions&>(*args.options));
257 return visitor.Create();
258 }
259
260 Result<std::unique_ptr<KernelState>> MeanInit(KernelContext* ctx,
261 const KernelInitArgs& args) {
262 SumLikeInit<MeanImplDefault> visitor(
263 ctx, args.inputs[0].type,
264 static_cast<const ScalarAggregateOptions&>(*args.options));
265 return visitor.Create();
266 }
267
268 // ----------------------------------------------------------------------
269 // Product implementation
270
271 using arrow::compute::internal::to_unsigned;
272
// Aggregator computing the product of all (non-null) values. Accumulation is
// done in the wider accumulator type chosen by FindAccumulatorType;
// MultiplyTraits handles overflow wrapping and decimal scale semantics.
template <typename ArrowType>
struct ProductImpl : public ScalarAggregator {
  using ThisType = ProductImpl<ArrowType>;
  using AccType = typename FindAccumulatorType<ArrowType>::Type;
  using ProductType = typename TypeTraits<AccType>::CType;
  using OutputType = typename TypeTraits<AccType>::ScalarType;

  explicit ProductImpl(const std::shared_ptr<DataType>& out_type,
                       const ScalarAggregateOptions& options)
      : out_type(out_type),
        options(options),
        count(0),
        // Start from the multiplicative identity for the output type.
        product(MultiplyTraits<AccType>::one(*out_type)),
        nulls_observed(false) {}

  Status Consume(KernelContext*, const ExecBatch& batch) override {
    if (batch[0].is_array()) {
      const auto& data = batch[0].array();
      this->count += data->length - data->GetNullCount();
      this->nulls_observed = this->nulls_observed || data->GetNullCount();

      if (!options.skip_nulls && this->nulls_observed) {
        // Short-circuit: with skip_nulls=false a single null nulls the result,
        // so there is no point multiplying further.
        return Status::OK();
      }

      internal::VisitArrayValuesInline<ArrowType>(
          *data,
          [&](typename TypeTraits<ArrowType>::CType value) {
            this->product =
                MultiplyTraits<AccType>::Multiply(*out_type, this->product, value);
          },
          [] {});  // nulls are skipped (already accounted for above)
    } else {
      // A scalar stands for batch.length copies of one value, so multiply it
      // in batch.length times.
      const auto& data = *batch[0].scalar();
      this->count += data.is_valid * batch.length;
      this->nulls_observed = this->nulls_observed || !data.is_valid;
      if (data.is_valid) {
        for (int64_t i = 0; i < batch.length; i++) {
          auto value = internal::UnboxScalar<ArrowType>::Unbox(data);
          this->product =
              MultiplyTraits<AccType>::Multiply(*out_type, this->product, value);
        }
      }
    }
    return Status::OK();
  }

  Status MergeFrom(KernelContext*, KernelState&& src) override {
    // Products of disjoint partitions combine by multiplication.
    const auto& other = checked_cast<const ThisType&>(src);
    this->count += other.count;
    this->product =
        MultiplyTraits<AccType>::Multiply(*out_type, this->product, other.product);
    this->nulls_observed = this->nulls_observed || other.nulls_observed;
    return Status::OK();
  }

  Status Finalize(KernelContext*, Datum* out) override {
    if ((!options.skip_nulls && this->nulls_observed) ||
        (this->count < options.min_count)) {
      // Not enough valid input (or a null with skip_nulls=false): emit null.
      out->value = std::make_shared<OutputType>(out_type);
    } else {
      out->value = std::make_shared<OutputType>(this->product, out_type);
    }
    return Status::OK();
  }

  std::shared_ptr<DataType> out_type;
  ScalarAggregateOptions options;
  size_t count;          // number of valid values consumed
  ProductType product;   // running product in the accumulator type
  bool nulls_observed;   // whether any null was seen
};
346
// Type visitor building the ProductImpl state for the "product" function.
// Numeric/boolean inputs accumulate into the wider AccType singleton;
// decimals keep the input type (precision/scale carried through).
struct ProductInit {
  std::unique_ptr<KernelState> state;
  KernelContext* ctx;
  const std::shared_ptr<DataType>& type;
  const ScalarAggregateOptions& options;

  ProductInit(KernelContext* ctx, const std::shared_ptr<DataType>& type,
              const ScalarAggregateOptions& options)
      : ctx(ctx), type(type), options(options) {}

  // Fallback: any type not matched below is unsupported.
  Status Visit(const DataType&) {
    return Status::NotImplemented("No product implemented");
  }

  // HalfFloat is numeric but has no arithmetic support, so exclude it
  // explicitly from the enable_if_number overload below.
  Status Visit(const HalfFloatType&) {
    return Status::NotImplemented("No product implemented");
  }

  Status Visit(const BooleanType&) {
    auto ty = TypeTraits<typename ProductImpl<BooleanType>::AccType>::type_singleton();
    state.reset(new ProductImpl<BooleanType>(ty, options));
    return Status::OK();
  }

  template <typename Type>
  enable_if_number<Type, Status> Visit(const Type&) {
    auto ty = TypeTraits<typename ProductImpl<Type>::AccType>::type_singleton();
    state.reset(new ProductImpl<Type>(ty, options));
    return Status::OK();
  }

  template <typename Type>
  enable_if_decimal<Type, Status> Visit(const Type&) {
    // Decimal output keeps the concrete input type (precision/scale).
    state.reset(new ProductImpl<Type>(type, options));
    return Status::OK();
  }

  Result<std::unique_ptr<KernelState>> Create() {
    RETURN_NOT_OK(VisitTypeInline(*type, this));
    return std::move(state);
  }

  // KernelInit entry point registered with the "product" function.
  static Result<std::unique_ptr<KernelState>> Init(KernelContext* ctx,
                                                   const KernelInitArgs& args) {
    ProductInit visitor(ctx, args.inputs[0].type,
                        static_cast<const ScalarAggregateOptions&>(*args.options));
    return visitor.Create();
  }
};
396
397 // ----------------------------------------------------------------------
398 // MinMax implementation
399
400 Result<std::unique_ptr<KernelState>> MinMaxInit(KernelContext* ctx,
401 const KernelInitArgs& args) {
402 ARROW_ASSIGN_OR_RAISE(auto out_type,
403 args.kernel->signature->out_type().Resolve(ctx, args.inputs));
404 MinMaxInitState<SimdLevel::NONE> visitor(
405 ctx, *args.inputs[0].type, std::move(out_type.type),
406 static_cast<const ScalarAggregateOptions&>(*args.options));
407 return visitor.Create();
408 }
409
410 // For "min" and "max" functions: override finalize and return the actual value
template <MinOrMax min_or_max>
void AddMinOrMaxAggKernel(ScalarAggregateFunction* func,
                          ScalarAggregateFunction* min_max_func) {
  // Registers a "min" or "max" kernel that piggybacks on the existing
  // "min_max" kernels: it runs min_max's state machinery and then extracts
  // one field of the resulting struct scalar in finalize.
  auto sig = KernelSignature::Make(
      {InputType(ValueDescr::ANY)},
      OutputType([](KernelContext*,
                    const std::vector<ValueDescr>& descrs) -> Result<ValueDescr> {
        // any[T] -> scalar[T]
        return ValueDescr::Scalar(descrs.front().type);
      }));

  // Init: delegate to the best-matching min_max kernel for the actual input
  // type and reuse its state object as ours.
  auto init = [min_max_func](
                  KernelContext* ctx,
                  const KernelInitArgs& args) -> Result<std::unique_ptr<KernelState>> {
    std::vector<ValueDescr> inputs = args.inputs;
    ARROW_ASSIGN_OR_RAISE(auto kernel, min_max_func->DispatchBest(&inputs));
    KernelInitArgs new_args{kernel, inputs, args.options};
    return kernel->init(ctx, new_args);
  };

  // Finalize: run min_max's finalize, then pick the "min" or "max" field
  // (indexed by the enum value) out of the struct scalar it produced.
  auto finalize = [](KernelContext* ctx, Datum* out) -> Status {
    Datum temp;
    RETURN_NOT_OK(checked_cast<ScalarAggregator*>(ctx->state())->Finalize(ctx, &temp));
    const auto& result = temp.scalar_as<StructScalar>();
    DCHECK(result.is_valid);
    *out = result.value[static_cast<uint8_t>(min_or_max)];
    return Status::OK();
  };

  // Note SIMD level is always NONE, but the convenience kernel will
  // dispatch to an appropriate implementation
  AddAggKernel(std::move(sig), std::move(init), std::move(finalize), func);
}
444
445 // ----------------------------------------------------------------------
446 // Any implementation
447
448 struct BooleanAnyImpl : public ScalarAggregator {
449 explicit BooleanAnyImpl(ScalarAggregateOptions options) : options(std::move(options)) {}
450
451 Status Consume(KernelContext*, const ExecBatch& batch) override {
452 // short-circuit if seen a True already
453 if (this->any == true && this->count >= options.min_count) {
454 return Status::OK();
455 }
456 if (batch[0].is_scalar()) {
457 const auto& scalar = *batch[0].scalar();
458 this->has_nulls = !scalar.is_valid;
459 this->any = scalar.is_valid && checked_cast<const BooleanScalar&>(scalar).value;
460 this->count += scalar.is_valid;
461 return Status::OK();
462 }
463 const auto& data = *batch[0].array();
464 this->has_nulls = data.GetNullCount() > 0;
465 this->count += data.length - data.GetNullCount();
466 arrow::internal::OptionalBinaryBitBlockCounter counter(
467 data.buffers[0], data.offset, data.buffers[1], data.offset, data.length);
468 int64_t position = 0;
469 while (position < data.length) {
470 const auto block = counter.NextAndBlock();
471 if (block.popcount > 0) {
472 this->any = true;
473 break;
474 }
475 position += block.length;
476 }
477 return Status::OK();
478 }
479
480 Status MergeFrom(KernelContext*, KernelState&& src) override {
481 const auto& other = checked_cast<const BooleanAnyImpl&>(src);
482 this->any |= other.any;
483 this->has_nulls |= other.has_nulls;
484 this->count += other.count;
485 return Status::OK();
486 }
487
488 Status Finalize(KernelContext* ctx, Datum* out) override {
489 if ((!options.skip_nulls && !this->any && this->has_nulls) ||
490 this->count < options.min_count) {
491 out->value = std::make_shared<BooleanScalar>();
492 } else {
493 out->value = std::make_shared<BooleanScalar>(this->any);
494 }
495 return Status::OK();
496 }
497
498 bool any = false;
499 bool has_nulls = false;
500 int64_t count = 0;
501 ScalarAggregateOptions options;
502 };
503
504 Result<std::unique_ptr<KernelState>> AnyInit(KernelContext*, const KernelInitArgs& args) {
505 const ScalarAggregateOptions options =
506 static_cast<const ScalarAggregateOptions&>(*args.options);
507 return ::arrow::internal::make_unique<BooleanAnyImpl>(
508 static_cast<const ScalarAggregateOptions&>(*args.options));
509 }
510
511 // ----------------------------------------------------------------------
512 // All implementation
513
514 struct BooleanAllImpl : public ScalarAggregator {
515 explicit BooleanAllImpl(ScalarAggregateOptions options) : options(std::move(options)) {}
516
517 Status Consume(KernelContext*, const ExecBatch& batch) override {
518 // short-circuit if seen a false already
519 if (this->all == false && this->count >= options.min_count) {
520 return Status::OK();
521 }
522 // short-circuit if seen a null already
523 if (!options.skip_nulls && this->has_nulls) {
524 return Status::OK();
525 }
526 if (batch[0].is_scalar()) {
527 const auto& scalar = *batch[0].scalar();
528 this->has_nulls = !scalar.is_valid;
529 this->count += scalar.is_valid;
530 this->all = !scalar.is_valid || checked_cast<const BooleanScalar&>(scalar).value;
531 return Status::OK();
532 }
533 const auto& data = *batch[0].array();
534 this->has_nulls = data.GetNullCount() > 0;
535 this->count += data.length - data.GetNullCount();
536 arrow::internal::OptionalBinaryBitBlockCounter counter(
537 data.buffers[1], data.offset, data.buffers[0], data.offset, data.length);
538 int64_t position = 0;
539 while (position < data.length) {
540 const auto block = counter.NextOrNotBlock();
541 if (!block.AllSet()) {
542 this->all = false;
543 break;
544 }
545 position += block.length;
546 }
547
548 return Status::OK();
549 }
550
551 Status MergeFrom(KernelContext*, KernelState&& src) override {
552 const auto& other = checked_cast<const BooleanAllImpl&>(src);
553 this->all &= other.all;
554 this->has_nulls |= other.has_nulls;
555 this->count += other.count;
556 return Status::OK();
557 }
558
559 Status Finalize(KernelContext*, Datum* out) override {
560 if ((!options.skip_nulls && this->all && this->has_nulls) ||
561 this->count < options.min_count) {
562 out->value = std::make_shared<BooleanScalar>();
563 } else {
564 out->value = std::make_shared<BooleanScalar>(this->all);
565 }
566 return Status::OK();
567 }
568
569 bool all = true;
570 bool has_nulls = false;
571 int64_t count = 0;
572 ScalarAggregateOptions options;
573 };
574
575 Result<std::unique_ptr<KernelState>> AllInit(KernelContext*, const KernelInitArgs& args) {
576 return ::arrow::internal::make_unique<BooleanAllImpl>(
577 static_cast<const ScalarAggregateOptions&>(*args.options));
578 }
579
580 // ----------------------------------------------------------------------
581 // Index implementation
582
// Aggregator for "index": finds the position of the first occurrence of
// options.value, or -1 if absent. Status::Cancelled("Found") is used
// internally as an early-exit signal from the value visitor.
template <typename ArgType>
struct IndexImpl : public ScalarAggregator {
  using ArgValue = typename internal::GetViewType<ArgType>::T;

  explicit IndexImpl(IndexOptions options, KernelState* raw_state)
      : options(std::move(options)), seen(0), index(-1) {
    // Resume from a previous state's progress if one is supplied.
    if (auto state = static_cast<IndexImpl<ArgType>*>(raw_state)) {
      seen = state->seen;
      index = state->index;
    }
  }

  Status Consume(KernelContext* ctx, const ExecBatch& batch) override {
    // short-circuit: already found, or searching for null (never matches)
    if (index >= 0 || !options.value->is_valid) {
      return Status::OK();
    }

    const ArgValue desired = internal::UnboxScalar<ArgType>::Unbox(*options.value);

    if (batch[0].is_scalar()) {
      // NOTE(review): `seen` is assigned, not accumulated, per Consume call,
      // and a Cancelled status is returned to the caller here — verify how
      // the execution layer treats both; this mirrors the array path below.
      seen = batch.length;
      if (batch[0].scalar()->is_valid) {
        const ArgValue v = internal::UnboxScalar<ArgType>::Unbox(*batch[0].scalar());
        if (v == desired) {
          index = 0;
          return Status::Cancelled("Found");
        }
      }
      return Status::OK();
    }

    auto input = batch[0].array();
    seen = input->length;
    int64_t i = 0;  // running position within this batch (nulls included)

    // The visitor returns Cancelled to stop scanning once a match is found;
    // that sentinel status is deliberately discarded via ARROW_UNUSED.
    ARROW_UNUSED(internal::VisitArrayValuesInline<ArgType>(
        *input,
        [&](ArgValue v) -> Status {
          if (v == desired) {
            index = i;
            return Status::Cancelled("Found");
          } else {
            ++i;
            return Status::OK();
          }
        },
        [&]() -> Status {
          // nulls never match, but still advance the position
          ++i;
          return Status::OK();
        }));

    return Status::OK();
  }

  Status MergeFrom(KernelContext*, KernelState&& src) override {
    const auto& other = checked_cast<const IndexImpl&>(src);
    // If this state found nothing but the (later) state did, shift the other
    // state's local index by the number of values this state scanned.
    if (index < 0 && other.index >= 0) {
      index = seen + other.index;
    }
    seen += other.seen;
    return Status::OK();
  }

  Status Finalize(KernelContext*, Datum* out) override {
    out->value = std::make_shared<Int64Scalar>(index >= 0 ? index : -1);
    return Status::OK();
  }

  const IndexOptions options;
  int64_t seen = 0;   // number of values scanned (for merge offsetting)
  int64_t index = -1; // first match position, or -1 when not found
};
656
// Type visitor building the typed IndexImpl state for the "index" function.
// Supported: boolean, numbers, base binary, date, time, timestamp.
struct IndexInit {
  std::unique_ptr<KernelState> state;
  KernelContext* ctx;
  const IndexOptions& options;
  const DataType& type;

  IndexInit(KernelContext* ctx, const IndexOptions& options, const DataType& type)
      : ctx(ctx), options(options), type(type) {}

  // Fallback: any type not matched below is unsupported.
  Status Visit(const DataType& type) {
    return Status::NotImplemented("Index kernel not implemented for ", type.ToString());
  }

  Status Visit(const BooleanType&) {
    state.reset(new IndexImpl<BooleanType>(options, ctx->state()));
    return Status::OK();
  }

  template <typename Type>
  enable_if_number<Type, Status> Visit(const Type&) {
    state.reset(new IndexImpl<Type>(options, ctx->state()));
    return Status::OK();
  }

  template <typename Type>
  enable_if_base_binary<Type, Status> Visit(const Type&) {
    state.reset(new IndexImpl<Type>(options, ctx->state()));
    return Status::OK();
  }

  template <typename Type>
  enable_if_date<Type, Status> Visit(const Type&) {
    state.reset(new IndexImpl<Type>(options, ctx->state()));
    return Status::OK();
  }

  template <typename Type>
  enable_if_time<Type, Status> Visit(const Type&) {
    state.reset(new IndexImpl<Type>(options, ctx->state()));
    return Status::OK();
  }

  template <typename Type>
  enable_if_timestamp<Type, Status> Visit(const Type&) {
    state.reset(new IndexImpl<Type>(options, ctx->state()));
    return Status::OK();
  }

  Result<std::unique_ptr<KernelState>> Create() {
    RETURN_NOT_OK(VisitTypeInline(type, this));
    return std::move(state);
  }

  // KernelInit entry point registered with the "index" function.
  static Result<std::unique_ptr<KernelState>> Init(KernelContext* ctx,
                                                   const KernelInitArgs& args) {
    // IndexOptions carries the search value, so it is mandatory.
    if (!args.options) {
      return Status::Invalid("Must provide IndexOptions for index kernel");
    }
    IndexInit visitor(ctx, static_cast<const IndexOptions&>(*args.options),
                      *args.inputs[0].type);
    return visitor.Create();
  }
};
720
721 } // namespace
722
723 void AddBasicAggKernels(KernelInit init,
724 const std::vector<std::shared_ptr<DataType>>& types,
725 std::shared_ptr<DataType> out_ty, ScalarAggregateFunction* func,
726 SimdLevel::type simd_level) {
727 for (const auto& ty : types) {
728 // array[InT] -> scalar[OutT]
729 auto sig =
730 KernelSignature::Make({InputType::Array(ty->id())}, ValueDescr::Scalar(out_ty));
731 AddAggKernel(std::move(sig), init, func, simd_level);
732 }
733 }
734
735 void AddScalarAggKernels(KernelInit init,
736 const std::vector<std::shared_ptr<DataType>>& types,
737 std::shared_ptr<DataType> out_ty,
738 ScalarAggregateFunction* func) {
739 for (const auto& ty : types) {
740 // scalar[InT] -> scalar[OutT]
741 auto sig =
742 KernelSignature::Make({InputType::Scalar(ty->id())}, ValueDescr::Scalar(out_ty));
743 AddAggKernel(std::move(sig), init, func, SimdLevel::NONE);
744 }
745 }
746
747 void AddArrayScalarAggKernels(KernelInit init,
748 const std::vector<std::shared_ptr<DataType>>& types,
749 std::shared_ptr<DataType> out_ty,
750 ScalarAggregateFunction* func,
751 SimdLevel::type simd_level = SimdLevel::NONE) {
752 AddBasicAggKernels(init, types, out_ty, func, simd_level);
753 AddScalarAggKernels(init, types, out_ty, func);
754 }
755
756 namespace {
757
758 Result<ValueDescr> MinMaxType(KernelContext*, const std::vector<ValueDescr>& descrs) {
759 // any[T] -> scalar[struct<min: T, max: T>]
760 auto ty = descrs.front().type;
761 return ValueDescr::Scalar(struct_({field("min", ty), field("max", ty)}));
762 }
763
764 } // namespace
765
766 void AddMinMaxKernel(KernelInit init, internal::detail::GetTypeId get_id,
767 ScalarAggregateFunction* func, SimdLevel::type simd_level) {
768 auto sig = KernelSignature::Make({InputType(get_id.id)}, OutputType(MinMaxType));
769 AddAggKernel(std::move(sig), init, func, simd_level);
770 }
771
772 void AddMinMaxKernels(KernelInit init,
773 const std::vector<std::shared_ptr<DataType>>& types,
774 ScalarAggregateFunction* func, SimdLevel::type simd_level) {
775 for (const auto& ty : types) {
776 AddMinMaxKernel(init, ty, func, simd_level);
777 }
778 }
779
780 namespace {
781
782 Result<ValueDescr> ScalarFirstType(KernelContext*,
783 const std::vector<ValueDescr>& descrs) {
784 ValueDescr result = descrs.front();
785 result.shape = ValueDescr::SCALAR;
786 return result;
787 }
788
// User-facing documentation attached to each aggregate function at
// registration time: summary, description, argument names, options class.
const FunctionDoc count_doc{"Count the number of null / non-null values",
                            ("By default, only non-null values are counted.\n"
                             "This can be changed through CountOptions."),
                            {"array"},
                            "CountOptions"};

const FunctionDoc count_distinct_doc{"Count the number of unique values",
                                     ("By default, only non-null values are counted.\n"
                                      "This can be changed through CountOptions."),
                                     {"array"},
                                     "CountOptions"};

const FunctionDoc sum_doc{
    "Compute the sum of a numeric array",
    ("Null values are ignored by default. Minimum count of non-null\n"
     "values can be set and null is returned if too few are present.\n"
     "This can be changed through ScalarAggregateOptions."),
    {"array"},
    "ScalarAggregateOptions"};

const FunctionDoc product_doc{
    "Compute the product of values in a numeric array",
    ("Null values are ignored by default. Minimum count of non-null\n"
     "values can be set and null is returned if too few are present.\n"
     "This can be changed through ScalarAggregateOptions."),
    {"array"},
    "ScalarAggregateOptions"};

const FunctionDoc mean_doc{
    "Compute the mean of a numeric array",
    ("Null values are ignored by default. Minimum count of non-null\n"
     "values can be set and null is returned if too few are present.\n"
     "This can be changed through ScalarAggregateOptions.\n"
     "The result is a double for integer and floating point arguments,\n"
     "and a decimal with the same bit-width/precision/scale for decimal arguments.\n"
     "For integers and floats, NaN is returned if min_count = 0 and\n"
     "there are no values. For decimals, null is returned instead."),
    {"array"},
    "ScalarAggregateOptions"};

const FunctionDoc min_max_doc{"Compute the minimum and maximum values of a numeric array",
                              ("Null values are ignored by default.\n"
                               "This can be changed through ScalarAggregateOptions."),
                              {"array"},
                              "ScalarAggregateOptions"};

const FunctionDoc min_or_max_doc{
    "Compute the minimum or maximum values of a numeric array",
    ("Null values are ignored by default.\n"
     "This can be changed through ScalarAggregateOptions."),
    {"array"},
    "ScalarAggregateOptions"};

const FunctionDoc any_doc{"Test whether any element in a boolean array evaluates to true",
                          ("Null values are ignored by default.\n"
                           "If null values are taken into account by setting "
                           "ScalarAggregateOptions parameter skip_nulls = false then "
                           "Kleene logic is used.\n"
                           "See KleeneOr for more details on Kleene logic."),
                          {"array"},
                          "ScalarAggregateOptions"};

const FunctionDoc all_doc{"Test whether all elements in a boolean array evaluate to true",
                          ("Null values are ignored by default.\n"
                           "If null values are taken into account by setting "
                           "ScalarAggregateOptions parameter skip_nulls = false then "
                           "Kleene logic is used.\n"
                           "See KleeneAnd for more details on Kleene logic."),
                          {"array"},
                          "ScalarAggregateOptions"};

const FunctionDoc index_doc{"Find the index of the first occurrence of a given value",
                            ("The result is always computed as an int64_t, regardless\n"
                             "of the offset type of the input array."),
                            {"array"},
                            "IndexOptions"};
865
866 } // namespace
867
// Registers the basic scalar aggregate functions — "count", "count_distinct",
// "sum", "mean", "min_max", "min", "max", "product", "any", "all" and
// "index" — with the given registry, adding runtime-dispatched SIMD (AVX2 /
// AVX512) kernel variants where both the build and the host CPU support them.
void RegisterScalarAggregateBasic(FunctionRegistry* registry) {
  // Default options objects live in function-local statics so that the
  // registered functions can hold pointers to them for the process lifetime.
  static auto default_scalar_aggregate_options = ScalarAggregateOptions::Defaults();
  static auto default_count_options = CountOptions::Defaults();

  auto func = std::make_shared<ScalarAggregateFunction>(
      "count", Arity::Unary(), &count_doc, &default_count_options);

  // Takes any input, outputs int64 scalar
  InputType any_input;
  AddAggKernel(KernelSignature::Make({any_input}, ValueDescr::Scalar(int64())), CountInit,
               func.get());
  DCHECK_OK(registry->AddFunction(std::move(func)));

  func = std::make_shared<ScalarAggregateFunction>(
      "count_distinct", Arity::Unary(), &count_distinct_doc, &default_count_options);
  // Takes any input, outputs int64 scalar
  AddCountDistinctKernels(func.get());
  DCHECK_OK(registry->AddFunction(std::move(func)));

  // "sum": boolean input is registered with uint64 output; signed ints widen
  // to int64, unsigned to uint64, floats to float64; decimal kernels use
  // ScalarFirstType, i.e. the output keeps the input's decimal type.
  func = std::make_shared<ScalarAggregateFunction>("sum", Arity::Unary(), &sum_doc,
                                                   &default_scalar_aggregate_options);
  AddArrayScalarAggKernels(SumInit, {boolean()}, uint64(), func.get());
  AddAggKernel(
      KernelSignature::Make({InputType(Type::DECIMAL128)}, OutputType(ScalarFirstType)),
      SumInit, func.get(), SimdLevel::NONE);
  AddAggKernel(
      KernelSignature::Make({InputType(Type::DECIMAL256)}, OutputType(ScalarFirstType)),
      SumInit, func.get(), SimdLevel::NONE);
  AddArrayScalarAggKernels(SumInit, SignedIntTypes(), int64(), func.get());
  AddArrayScalarAggKernels(SumInit, UnsignedIntTypes(), uint64(), func.get());
  AddArrayScalarAggKernels(SumInit, FloatingPointTypes(), float64(), func.get());
  // Add the SIMD variants for sum
#if defined(ARROW_HAVE_RUNTIME_AVX2) || defined(ARROW_HAVE_RUNTIME_AVX512)
  // Queried once; reused by the mean and min_max SIMD registrations below,
  // which are guarded by the same build-time macros.
  auto cpu_info = arrow::internal::CpuInfo::GetInstance();
#endif
#if defined(ARROW_HAVE_RUNTIME_AVX2)
  if (cpu_info->IsSupported(arrow::internal::CpuInfo::AVX2)) {
    AddSumAvx2AggKernels(func.get());
  }
#endif
#if defined(ARROW_HAVE_RUNTIME_AVX512)
  if (cpu_info->IsSupported(arrow::internal::CpuInfo::AVX512)) {
    AddSumAvx512AggKernels(func.get());
  }
#endif
  DCHECK_OK(registry->AddFunction(std::move(func)));

  // "mean": boolean and numeric inputs are registered with float64 output;
  // decimal kernels keep the input's decimal type (ScalarFirstType).
  func = std::make_shared<ScalarAggregateFunction>("mean", Arity::Unary(), &mean_doc,
                                                   &default_scalar_aggregate_options);
  AddArrayScalarAggKernels(MeanInit, {boolean()}, float64(), func.get());
  AddArrayScalarAggKernels(MeanInit, NumericTypes(), float64(), func.get());
  AddAggKernel(
      KernelSignature::Make({InputType(Type::DECIMAL128)}, OutputType(ScalarFirstType)),
      MeanInit, func.get(), SimdLevel::NONE);
  AddAggKernel(
      KernelSignature::Make({InputType(Type::DECIMAL256)}, OutputType(ScalarFirstType)),
      MeanInit, func.get(), SimdLevel::NONE);
  // Add the SIMD variants for mean
#if defined(ARROW_HAVE_RUNTIME_AVX2)
  if (cpu_info->IsSupported(arrow::internal::CpuInfo::AVX2)) {
    AddMeanAvx2AggKernels(func.get());
  }
#endif
#if defined(ARROW_HAVE_RUNTIME_AVX512)
  if (cpu_info->IsSupported(arrow::internal::CpuInfo::AVX512)) {
    AddMeanAvx512AggKernels(func.get());
  }
#endif
  DCHECK_OK(registry->AddFunction(std::move(func)));

  // "min_max": covers null/boolean, numeric, temporal, binary-like,
  // fixed-size binary, month interval and decimal inputs.
  func = std::make_shared<ScalarAggregateFunction>(
      "min_max", Arity::Unary(), &min_max_doc, &default_scalar_aggregate_options);
  AddMinMaxKernels(MinMaxInit, {null(), boolean()}, func.get());
  AddMinMaxKernels(MinMaxInit, NumericTypes(), func.get());
  AddMinMaxKernels(MinMaxInit, TemporalTypes(), func.get());
  AddMinMaxKernels(MinMaxInit, BaseBinaryTypes(), func.get());
  AddMinMaxKernel(MinMaxInit, Type::FIXED_SIZE_BINARY, func.get());
  AddMinMaxKernel(MinMaxInit, Type::INTERVAL_MONTHS, func.get());
  AddMinMaxKernel(MinMaxInit, Type::DECIMAL128, func.get());
  AddMinMaxKernel(MinMaxInit, Type::DECIMAL256, func.get());
  // Add the SIMD variants for min max
#if defined(ARROW_HAVE_RUNTIME_AVX2)
  if (cpu_info->IsSupported(arrow::internal::CpuInfo::AVX2)) {
    AddMinMaxAvx2AggKernels(func.get());
  }
#endif
#if defined(ARROW_HAVE_RUNTIME_AVX512)
  if (cpu_info->IsSupported(arrow::internal::CpuInfo::AVX512)) {
    AddMinMaxAvx512AggKernels(func.get());
  }
#endif

  // Non-owning pointer captured before the shared_ptr is moved into the
  // registry, so the "min"/"max" wrappers below can delegate to "min_max";
  // the registry's ownership keeps the pointed-to function alive.
  auto min_max_func = func.get();
  DCHECK_OK(registry->AddFunction(std::move(func)));

  // Add min/max as convenience functions
  func = std::make_shared<ScalarAggregateFunction>("min", Arity::Unary(), &min_or_max_doc,
                                                   &default_scalar_aggregate_options);
  AddMinOrMaxAggKernel<MinOrMax::Min>(func.get(), min_max_func);
  DCHECK_OK(registry->AddFunction(std::move(func)));

  func = std::make_shared<ScalarAggregateFunction>("max", Arity::Unary(), &min_or_max_doc,
                                                   &default_scalar_aggregate_options);
  AddMinOrMaxAggKernel<MinOrMax::Max>(func.get(), min_max_func);
  DCHECK_OK(registry->AddFunction(std::move(func)));

  // "product": same output-type registration scheme as "sum" above.
  func = std::make_shared<ScalarAggregateFunction>(
      "product", Arity::Unary(), &product_doc, &default_scalar_aggregate_options);
  AddArrayScalarAggKernels(ProductInit::Init, {boolean()}, uint64(), func.get());
  AddArrayScalarAggKernels(ProductInit::Init, SignedIntTypes(), int64(), func.get());
  AddArrayScalarAggKernels(ProductInit::Init, UnsignedIntTypes(), uint64(), func.get());
  AddArrayScalarAggKernels(ProductInit::Init, FloatingPointTypes(), float64(),
                           func.get());
  AddAggKernel(
      KernelSignature::Make({InputType(Type::DECIMAL128)}, OutputType(ScalarFirstType)),
      ProductInit::Init, func.get(), SimdLevel::NONE);
  AddAggKernel(
      KernelSignature::Make({InputType(Type::DECIMAL256)}, OutputType(ScalarFirstType)),
      ProductInit::Init, func.get(), SimdLevel::NONE);
  DCHECK_OK(registry->AddFunction(std::move(func)));

  // any: boolean in, boolean scalar out.
  func = std::make_shared<ScalarAggregateFunction>("any", Arity::Unary(), &any_doc,
                                                   &default_scalar_aggregate_options);
  AddArrayScalarAggKernels(AnyInit, {boolean()}, boolean(), func.get());
  DCHECK_OK(registry->AddFunction(std::move(func)));

  // all: boolean in, boolean scalar out.
  func = std::make_shared<ScalarAggregateFunction>("all", Arity::Unary(), &all_doc,
                                                   &default_scalar_aggregate_options);
  AddArrayScalarAggKernels(AllInit, {boolean()}, boolean(), func.get());
  DCHECK_OK(registry->AddFunction(std::move(func)));

  // index: no default options pointer is registered here, so callers must
  // supply IndexOptions; all kernels produce int64 output.
  func = std::make_shared<ScalarAggregateFunction>("index", Arity::Unary(), &index_doc);
  AddBasicAggKernels(IndexInit::Init, BaseBinaryTypes(), int64(), func.get());
  AddBasicAggKernels(IndexInit::Init, PrimitiveTypes(), int64(), func.get());
  AddBasicAggKernels(IndexInit::Init, TemporalTypes(), int64(), func.get());
  DCHECK_OK(registry->AddFunction(std::move(func)));
}
1008
1009 } // namespace internal
1010 } // namespace compute
1011 } // namespace arrow