]> git.proxmox.com Git - ceph.git/blob - ceph/src/arrow/cpp/src/arrow/compute/kernels/aggregate_var_std_internal.h
import quincy 17.2.0
[ceph.git] / ceph / src / arrow / cpp / src / arrow / compute / kernels / aggregate_var_std_internal.h
1 // Licensed to the Apache Software Foundation (ASF) under one
2 // or more contributor license agreements. See the NOTICE file
3 // distributed with this work for additional information
4 // regarding copyright ownership. The ASF licenses this file
5 // to you under the Apache License, Version 2.0 (the
6 // "License"); you may not use this file except in compliance
7 // with the License. You may obtain a copy of the License at
8 //
9 // http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing,
12 // software distributed under the License is distributed on an
13 // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14 // KIND, either express or implied. See the License for the
15 // specific language governing permissions and limitations
16 // under the License.
17
18 #pragma once
19
20 #include "arrow/util/int128_internal.h"
21
22 namespace arrow {
23 namespace compute {
24 namespace internal {
25
26 using arrow::internal::int128_t;
27
28 // Accumulate sum/squared sum (using naive summation)
29 // Shared implementation between scalar/hash aggregate variance/stddev kernels
30 template <typename ArrowType>
31 struct IntegerVarStd {
32 using c_type = typename ArrowType::c_type;
33
34 int64_t count = 0;
35 int64_t sum = 0;
36 int128_t square_sum = 0;
37
38 void ConsumeOne(const c_type value) {
39 sum += value;
40 square_sum += static_cast<uint64_t>(value) * value;
41 count++;
42 }
43
44 double mean() const { return static_cast<double>(sum) / count; }
45
46 double m2() const {
47 // calculate m2 = square_sum - sum * sum / count
48 // decompose `sum * sum / count` into integers and fractions
49 const int128_t sum_square = static_cast<int128_t>(sum) * sum;
50 const int128_t integers = sum_square / count;
51 const double fractions = static_cast<double>(sum_square % count) / count;
52 return static_cast<double>(square_sum - integers) - fractions;
53 }
54 };
55
// Merges the (count, mean, m2) summary of a second partition into the summary of
// a first one, where m2 is the sum of squared deviations from the partition mean.
//
//   count1, mean1 : count and mean of the first partition
//   count2, mean2 : count and mean of the second partition
//   m22           : m2 of the second partition
//   out_count     : incremented by count2
//   out_mean      : overwritten with the mean of the combined data
//   out_m2        : incremented by m22 plus the correction terms
//                   count_i * (mean_i - combined_mean)^2 of both partitions
//
// NOTE(review): because *out_count and *out_m2 are accumulated into rather than
// assigned, they are presumably expected to hold the first partition's count and
// m2 on entry -- confirm against the callers.
// There is no guard for count1 + count2 == 0; the division happens in floating
// point.
static inline void MergeVarStd(int64_t count1, double mean1, int64_t count2, double mean2,
                               double m22, int64_t* out_count, double* out_mean,
                               double* out_m2) {
  // Count-weighted average of the two partition means.
  const int64_t total_count = count1 + count2;
  const double weighted_sum = mean1 * count1 + mean2 * count2;
  const double merged_mean = weighted_sum / total_count;

  // Offset of each partition mean from the combined mean.
  const double delta1 = mean1 - merged_mean;
  const double delta2 = mean2 - merged_mean;

  *out_m2 += m22 + count1 * delta1 * delta1 + count2 * delta2 * delta2;
  *out_count += count2;
  *out_mean = merged_mean;
}
65
66 } // namespace internal
67 } // namespace compute
68 } // namespace arrow